diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 81acb9deacd..c286b8d3b52 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -334,8 +334,8 @@ public static boolean executeScript( String[] args ) LineageCacheConfig.setCachePolicy(LINEAGE_POLICY); LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE); - if (dmlOptions.oocLogEvents) - OOCEventLog.setup(100000); + if(dmlOptions.oocLogEvents) + OOCEventLog.setup(1000000); String dmlScriptStr = readDMLScript(isFile, fileOrScript); Map argVals = dmlOptions.argVals; diff --git a/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java index 0d15218bf9a..e745887d088 100644 --- a/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java +++ b/src/main/java/org/apache/sysds/api/ScriptExecutorUtils.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.gpu.context.GPUContext; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysds.runtime.instructions.gpu.context.GPUObject; @@ -45,7 +46,7 @@ public class ScriptExecutorUtils { * Execute the runtime program. This involves execution of the program * blocks that make up the runtime program and may involve dynamic * recompilation. - * + * * @param se * script executor * @param statisticsMaxHeavyHitters @@ -62,7 +63,7 @@ public static void executeRuntimeProgram(ScriptExecutor se, int statisticsMaxHea * Execute the runtime program. This involves execution of the program * blocks that make up the runtime program and may involve dynamic * recompilation. - * + * * @param rtprog * runtime program * @param ec @@ -82,7 +83,7 @@ public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DM List gCtxs = GPUContextPool.reserveAllGPUContexts(); if (gCtxs == null) { throw new DMLRuntimeException( - "GPU : Could not create GPUContext, either no GPU or all GPUs currently in use"); + "GPU : Could not create GPUContext, either no GPU or all GPUs currently in use"); } gCtxs.get(0).initializeThread(); ec.setGPUContexts(gCtxs); @@ -120,18 +121,21 @@ public static void executeRuntimeProgram(Program rtprog, ExecutionContext ec, DM } if( ConfigurationManager.isCodegenEnabled() ) SpoofCompiler.cleanupCodeGenerator(); - + // display statistics (incl caching stats if enabled) Statistics.stopRunTimer(); System.out.println(Statistics.display(statisticsMaxHeavyHitters > 0 ? - statisticsMaxHeavyHitters : DMLScript.STATISTICS_COUNT)); - + statisticsMaxHeavyHitters : DMLScript.STATISTICS_COUNT)); + if (DMLScript.LINEAGE_ESTIMATE) System.out.println(LineageEstimatorStatistics.displayLineageEstimates()); - if (DMLScript.USE_OOC) + if(DMLScript.USE_OOC) { + // Clean symbol-table entries so OOC streams count as de-referenced + if((outputVariables == null || outputVariables.isEmpty()) && ec != null) + ec.getVarList().forEach(var -> VariableCPInstruction.processRmvarInstruction(ec, var)); OOCCacheManager.reset(); + } } } - } diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 34da36dd13c..b3475edfbae 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -158,7 +158,7 @@ else if(_op == OpOp1.MEDIAN) else { // general case MATRIX ExecType et = optFindExecType(); // special handling cumsum/cumprod/cummin/cumsum - if(isCumulativeUnaryOperation() && !(et == ExecType.CP || et == ExecType.GPU)) { + if(isCumulativeUnaryOperation() && !(et == ExecType.CP || et == ExecType.GPU || et == ExecType.OOC)) { // TODO additional physical operation if offsets fit in memory ret = constructLopsSparkCumulativeUnary(); } diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 7bfea43c2e5..47e57610965 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -882,7 +882,8 @@ public String prepScalarOperand(ExecType et, String label) { boolean isLiteral = (isData && ((Data)this).isLiteral()); StringBuilder sb = new StringBuilder(""); - if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU || (isData && isLiteral)) { + if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.GPU || et == ExecType.OOC + || (isData && isLiteral)) { sb.append(label); } else { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 85f41ad7fe2..793cf39ce69 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -27,6 +27,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -729,12 +730,19 @@ private void executeLocalParFor( ExecutionContext ec, IntObject from, IntObject final LocalTaskQueue queue = new LocalTaskQueue<>(); final Thread[] threads = new Thread[_numThreads]; final LocalParWorker[] workers = new LocalParWorker[_numThreads]; + @SuppressWarnings("unchecked") + final HashMap[] workerBaselines = DMLScript.USE_OOC ? new HashMap[_numThreads] : null; try { // Step 1) create task queue and init workers in parallel // (including preparation of update-in-place variables) IntStream.range(0, _numThreads).forEach(i -> { workers[i] = createParallelWorker( _pwIDs[i], queue, ec, i); + if(DMLScript.USE_OOC) { + workerBaselines[i] = new HashMap<>(); + for(Map.Entry e : workers[i].getVariables().entrySet()) + workerBaselines[i].put(e.getKey(), e.getValue()); + } threads[i] = new Thread( workers[i] , "PARFOR"); threads[i].setPriority(Thread.MAX_PRIORITY); }); @@ -777,11 +785,21 @@ private void executeLocalParFor( ExecutionContext ec, IntObject from, IntObject // Step 4) collecting results from each parallel worker //obtain results and cleanup other intermediates before result merge - LocalVariableMap [] localVariables = new LocalVariableMap [_numThreads]; + Set resultVarNames = _resultVars.stream() + .map(v -> v._name).collect(Collectors.toSet()); + LocalVariableMap [] localVariables = new LocalVariableMap [_numThreads]; for( int i=0; i<_numThreads; i++ ) { localVariables[i] = workers[i].getVariables(); - localVariables[i].removeAllNotIn(_resultVars.stream() - .map(v -> v._name).collect(Collectors.toSet())); + if(DMLScript.USE_OOC) { + for(String var : localVariables[i].keySet()) { + if(!resultVarNames.contains(var)) { + Data current = localVariables[i].get(var); + if(current != null && current != workerBaselines[i].get(var)) + VariableCPInstruction.processRmvarInstruction(workers[i].getExecutionContext(), var); + } + } + } + localVariables[i].removeAllNotIn(resultVarNames); numExecutedTasks += workers[i].getExecutedTasks(); numExecutedIterations += workers[i].getExecutedIterations(); } @@ -797,6 +815,20 @@ private void executeLocalParFor( ExecutionContext ec, IntObject from, IntObject //consolidate results into global symbol table consolidateAndCheckResults( ec, numIterations, numCreatedTasks, numExecutedIterations, numExecutedTasks, localVariables ); + + if(DMLScript.USE_OOC) { + // Cleanup remaining variables + for(int i = 0; i < workers.length; i++) { + if(workers[i].getExecutedTasks() <= 0) + continue; + for(String var : localVariables[i].keySet()) { + Data current = localVariables[i].get(var); + if(current != null && current != workerBaselines[i].get(var)) + VariableCPInstruction.processRmvarInstruction(workers[i].getExecutionContext(), var); + } + } + + } // Step 5) cleanup local parworkers (e.g., remove created functions) for( int i=0; i<_numThreads; i++ ) { @@ -1392,11 +1424,19 @@ private void consolidateAndCheckResults(ExecutionContext ec, final long expIters CacheableData outNew = USE_PARALLEL_RESULT_MERGE ? rm.executeParallelMerge(_numThreads) : rm.executeSerialMerge(); - - //cleanup existing var - Data exdata = ec.removeVariable(var._name); - if( exdata != null && exdata != outNew ) - ec.cleanupDataObject(exdata); + + if(DMLScript.USE_OOC) { + //cleanup existing var with rmvar semantics to keep OOC ref counters consistent + Data exdata = ec.getVariable(var._name); + if( exdata != null && exdata != outNew ) + VariableCPInstruction.processRmvarInstruction(ec, var._name); + } + else { + //cleanup existing var + Data exdata = ec.removeVariable(var._name); + if( exdata != null && exdata != outNew ) + ec.cleanupDataObject(exdata); + } //cleanup of intermediate result variables cleanWorkerResultVariables( ec, out, in, true ); @@ -1607,6 +1647,11 @@ public void run() outNew = rm.executeSerialMerge(); synchronized( _ec.getVariables() ){ + if(DMLScript.USE_OOC) { + Data exdata = _ec.getVariable(var._name); + if(exdata != null && exdata != outNew) + VariableCPInstruction.processRmvarInstruction(_ec, var._name); + } _ec.getVariables().put( var._name, outNew); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index aee08516db6..75e3fc885e7 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory; +import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.lineage.LineageCache; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; @@ -293,8 +294,13 @@ protected UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec, long t moNew.setFileName(mo.getFileName() + Lop.UPDATE_INPLACE_PREFIX + tid); mo.release(); // cleanup old variable (e.g., remove from buffer pool) - if(ec.removeVariable(varname) != null) + if(DMLScript.USE_OOC) { + if(ec.containsVariable(varname)) + VariableCPInstruction.processRmvarInstruction(ec, varname); + } + else if(ec.removeVariable(varname) != null) { ec.cleanupCacheableData(mo); + } moNew.release(); // after old removal to avoid unnecessary evictions moNew.setUpdateType(UpdateType.INPLACE); ec.setVariable(varname, moNew); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java index 39e3fe488c1..47a854bd2ec 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java @@ -68,7 +68,8 @@ public double getLeafNodeEstimate(TestMeasure measure, OptNode node) value = DEFAULT_MEM_REMOTE + h.getSpBroadcastSize(); } //check for invalid cp memory estimate - else if ( h.getExecType()==ExecType.CP && value >= OptimizerUtils.getLocalMemBudget() ) { + else if ( (h.getExecType()==ExecType.CP || h.getExecType()==ExecType.OOC) + && value >= OptimizerUtils.getLocalMemBudget() ) { if( !forcedExec && !HopRewriteUtils.hasListInputs(h) ) LOG.warn("Memory estimate larger than budget but CP exec type (op="+h.getOpString()+", name="+h.getName()+", memest="+h.getMemEstimate()+")."); value = DEFAULT_MEM_REMOTE; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java index 4dd2249654c..dd52c9d3775 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java @@ -317,7 +317,7 @@ public static List rCreateAbstractOptNodes(Hop hop, LocalVariableMap va Types.ExecType et = (hop.getExecType()!=null) ? hop.getExecType() : Types.ExecType.CP; switch( et ) { - case CP:case GPU: + case CP:case GPU:case OOC: node.setExecType(ExecType.CP); break; case SPARK: node.setExecType(ExecType.SPARK); break; @@ -329,7 +329,7 @@ public static List rCreateAbstractOptNodes(Hop hop, LocalVariableMap va } //handle degree of parallelism - if( et == Types.ExecType.CP && hop instanceof MultiThreadedHop ){ + if( (et == Types.ExecType.CP || et == Types.ExecType.OOC) && hop instanceof MultiThreadedHop ){ MultiThreadedHop mtop = (MultiThreadedHop) hop; node.setK( OptimizerUtils.getConstrainedNumThreads(mtop.getMaxNumThreads()) ); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 607acbb3a0c..21816e90ad8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -94,6 +94,8 @@ else if(parts.length == 4) return TSMMOOCInstruction.parseInstruction(str); case Reorg: return ReorgOOCInstruction.parseInstruction(str); + case Reshape: + return ReorgOOCInstruction.parseInstruction(str); case Tee: return TeeOOCInstruction.parseInstruction(str); case CentralMoment: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java index b3f6a6a117c..12caa3c61ba 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -220,19 +220,29 @@ public void processInstruction(ExecutionContext ec) { throw new DMLRuntimeException("error executing function " + fname, e); } long t1 = !ReuseCacheType.isNone()||DMLScript.LINEAGE_ESTIMATE ? System.nanoTime() : 0; - - // cleanup all returned variables w/o binding + + // cleanup all returned variables w/o binding HashSet expectRetVars = new HashSet<>(); for(DataIdentifier di : fpb.getOutputParams()) expectRetVars.add(di.getName()); - + LocalVariableMap retVars = fn_ec.getVariables(); - for( String varName : new ArrayList<>(retVars.keySet()) ) { - if( expectRetVars.contains(varName) ) - continue; - //cleanup unexpected return values to avoid leaks - //(including OOC reference tracking for matrix streams) - VariableCPInstruction.processRmvarInstruction(fn_ec, varName); + if(DMLScript.USE_OOC) { + for( String varName : new ArrayList<>(retVars.keySet()) ) { + if( expectRetVars.contains(varName) ) + continue; + // cleanup unexpected return values to avoid leaks + // (including OOC reference tracking for matrix streams) + VariableCPInstruction.processRmvarInstruction(fn_ec, varName); + } + } + else { + for( String varName : new ArrayList<>(retVars.keySet()) ) { + if( expectRetVars.contains(varName) ) + continue; + // cleanup unexpected return values to avoid leaks + fn_ec.cleanupDataObject(fn_ec.removeVariable(varName)); + } } // Unpin the pinned variables @@ -245,7 +255,8 @@ public void processInstruction(ExecutionContext ec) { for (int i=0; i< numOutputs; i++) { String boundVarName = _boundOutputNames.get(i); String retVarName = fpb.getOutputParams().get(i).getName(); - Data boundValue = retVars.get(retVarName); + Data boundValue = DMLScript.USE_OOC ? + fn_ec.removeVariable(retVarName) : retVars.get(retVarName); if (boundValue == null) throw new DMLRuntimeException("fcall "+_functionName+": " +boundVarName + " was not assigned a return value"); @@ -288,11 +299,17 @@ public void processInstruction(ExecutionContext ec) { //FIXME: send _boundOutputNames instead of fpb.getOutputParams as //those are already replaced by boundoutput names in the lineage map. } + + if(DMLScript.USE_OOC) { + // Cleanup any remaining unbound outputs in function scope. + for( String varName : new ArrayList<>(fn_ec.getVariables().keySet()) ) + VariableCPInstruction.processRmvarInstruction(fn_ec, varName); - // cleanup declared outputs that are not bound at callsite - for (int i = numOutputs; i < fpb.getOutputParams().size(); i++) { - String retVarName = fpb.getOutputParams().get(i).getName(); - VariableCPInstruction.processRmvarInstruction(fn_ec, retVarName); + // cleanup declared outputs that are not bound at callsite + for (int i = numOutputs; i < fpb.getOutputParams().size(); i++) { + String retVarName = fpb.getOutputParams().get(i).getName(); + VariableCPInstruction.processRmvarInstruction(fn_ec, retVarName); + } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 9b161ea99d9..359df747e7b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -1036,6 +1036,8 @@ private void processCopyInstruction(ExecutionContext ec) { // remove existing variable bound to target name Data input2_data = ec.removeVariable(getInput2().getName()); + if (DMLScript.USE_OOC && input2_data instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject) input2_data).getStreamable(), -1); //cleanup matrix data on fs/hdfs (if necessary) if( input2_data != null ) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index db8091621da..f0d4fd29af7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -162,7 +162,7 @@ public void processInstruction( ExecutionContext ec ) { MatrixBlock ltmp; int extra = _aop.correction.getNumRemovedRowsColumns(); - MatrixBlock ret = new MatrixBlock(1,1+extra,false); + MatrixBlock ret = new MatrixBlock(1, 1 + extra, _aop.initialValue); MatrixBlock corr = new MatrixBlock(1,1+extra,false); while((ltmp = qLocal.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { OperationsOnMatrixValues.incrementalAggregation( diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index e45b8e93bc6..1352e7ff9c7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.instructions.ooc; import org.apache.commons.lang3.NotImplementedException; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -75,8 +74,26 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) { qIn2.messageUpstream(msg.split()); }); - if (m1.getNumRows() < 0 || m1.getNumColumns() < 0 || m2.getNumRows() < 0 || m2.getNumColumns() < 0) - throw new DMLRuntimeException("Cannot process (matrix, matrix) BinaryOOCInstruction with unknown dimensions."); + final boolean known1 = (m1.getNumRows() >= 0 && m1.getNumColumns() >= 0); + final boolean known2 = (m2.getNumRows() >= 0 && m2.getNumColumns() >= 0); + + // If dimensions are unknown, we cannot safely detect broadcasting. + // Fall back to strict key-based join and let downstream operators validate as needed. + if(!known1 || !known2) { + if(LOG.isWarnEnabled()) { + LOG.warn("Falling back to key-wise OOC binary join for opcode '" + getOpcode() + + "' due to unknown matrix dimensions: " + input1.getName() + "=" + m1.getNumRows() + "x" + + m1.getNumColumns() + ", " + input2.getName() + "=" + m2.getNumRows() + "x" + + m2.getNumColumns()); + } + joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp1.getIndexes(), + tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(), tmpOut.getValue())); + return tmpOut; + }, IndexedMatrixValue::getIndexes); + return; + } boolean isColBroadcast = m1.getNumColumns() > 1 && m2.getNumColumns() == 1; boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index 7e1bdac73d2..d8eac50fcb4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.commons.collections4.BidiMap; -import org.apache.commons.collections4.bidimap.DualHashBidiMap; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; @@ -39,9 +37,11 @@ import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; import java.util.ArrayList; -import java.util.List; import java.util.BitSet; +import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; @@ -77,7 +77,7 @@ public class CachingStream implements OOCStreamable { // state flags private boolean _cacheInProgress = true; // caching in progress, in the first pass. - private BidiMap _index; + private volatile Map _index; private DMLRuntimeException _failure; @@ -96,7 +96,7 @@ public CachingStream(OOCStream source, long streamId) { if(OOCWatchdog.WATCH) { _watchdogId = "CS-" + hashCode(); // Capture a short context to help identify origin - OOCWatchdog.registerOpen(_watchdogId, "CachingStream@" + hashCode(), getCtxMsg(), this); + OOCWatchdog.registerOpen(_watchdogId, toString(), getCtxMsg(), this); } _downstreamRelays = null; source.setSubscriber(tmp -> { @@ -111,119 +111,126 @@ public CachingStream(OOCStream source, long streamId) { if(!tmp.isEos()) { if(!_cacheInProgress) throw new DMLRuntimeException("Stream is closed"); - if(tmp instanceof OOCStream.GroupQueueCallback) { - @SuppressWarnings("unchecked") - OOCStream.GroupQueueCallback group = - (OOCStream.GroupQueueCallback) tmp; - groupSize = group.size(); + + if(tmp instanceof OOCStream.GroupQueueCallback) { + OOCStream.GroupQueueCallback group = + (OOCStream.GroupQueueCallback) tmp; + groupSize = group.size(); + for(int gi = 0; gi < groupSize; gi++) { + OOCStream.QueueCallback sub = group.getCallback(gi); + try(sub) { + IndexedMatrixValue imv = sub.get(); + if(_index != null) + _index.put(imv.getIndexes(), _numBlocks + gi); + } + } + + BlockKey baseKey; + boolean ownsEntry = true; + if(tmp instanceof OOCCacheManager.CachedGroupCallback cachedGroup) { + baseKey = cachedGroup.getBlockKey(); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); + } + else { + List values = new ArrayList<>(groupSize); + long totalSize = 0; for(int gi = 0; gi < groupSize; gi++) { OOCStream.QueueCallback sub = group.getCallback(gi); try(sub) { IndexedMatrixValue imv = sub.get(); - if(_index != null) - _index.put(imv.getIndexes(), _numBlocks + gi); + values.add(imv); + totalSize += ((MatrixBlock) imv.getValue()).getExactSerializedSize(); } } - BlockKey baseKey; - boolean ownsEntry = true; - if(tmp instanceof OOCCacheManager.CachedGroupCallback cachedGroup) { - baseKey = cachedGroup.getBlockKey(); - ownsEntry = false; - if(mSubscribers != null && mSubscribers.length > 0) - mCallback = tmp.keepOpen(); + baseKey = new BlockKey(_streamId, _numBlocks); + if (_source instanceof SourceOOCStream && tmp instanceof SourceOOCStream.SourceGroupCallback sg) { + OOCIOHandler.GroupSourceBlockDescriptor gdesc = sg.getDescriptor(); + if (mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putRawSourceBacked(baseKey, values, totalSize, gdesc); + else + mCallback = OOCCacheManager.putAndPinRawSourceBacked(baseKey, values, totalSize, gdesc); } else { - List values = new ArrayList<>(groupSize); - long totalSize = 0; - for(int gi = 0; gi < groupSize; gi++) { - OOCStream.QueueCallback sub = group.getCallback(gi); - try(sub) { - IndexedMatrixValue imv = sub.get(); - values.add(imv); - totalSize += ((MatrixBlock) imv.getValue()).getExactSerializedSize(); - } - } - - baseKey = new BlockKey(_streamId, _numBlocks); - if (_source instanceof SourceOOCStream && tmp instanceof SourceOOCStream.SourceGroupCallback sg) { - OOCIOHandler.GroupSourceBlockDescriptor gdesc = sg.getDescriptor(); - if (mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.putRawSourceBacked(baseKey, values, totalSize, gdesc); - else - mCallback = OOCCacheManager.putAndPinRawSourceBacked(baseKey, values, totalSize, gdesc); - } - else { - if(mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.putRaw(baseKey, values, totalSize); - else - mCallback = OOCCacheManager.putAndPinRaw(baseKey, values, totalSize); - } + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putRaw(baseKey, values, totalSize); + else + mCallback = OOCCacheManager.putAndPinRaw(baseKey, values, totalSize); } + } - blk = _numBlocks; - _numBlocks += groupSize; - for(int gi = 0; gi < groupSize; gi++) { - registerCacheKey(blk + gi, - new GroupedBlockKey(baseKey.getStreamId(), (int) baseKey.getSequenceNumber(), gi), - ownsEntry); - _consumptionCounts.add(0); - _groupIndices.add(gi); - _groupSizes.add(groupSize); - } + blk = _numBlocks; + _numBlocks += groupSize; + for(int gi = 0; gi < groupSize; gi++) { + registerCacheKey(blk + gi, + new GroupedBlockKey(baseKey.getStreamId(), (int) baseKey.getSequenceNumber(), gi), + ownsEntry); + _consumptionCounts.add(0); + _groupIndices.add(gi); + _groupSizes.add(groupSize); + if(_deletable) + tryDeleteBlock(blk + gi); + } + } + else { + final IndexedMatrixValue task = tmp.get(); + OOCIOHandler.SourceBlockDescriptor descriptor = null; + BlockKey blockKey = null; + boolean ownsEntry = true; + + if(tmp instanceof OOCCacheManager.CachedQueueCallback cachedQueue) { + blockKey = cachedQueue.getBlockKey(); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); + } + else if(tmp instanceof OOCCacheManager.CachedSubCallback cachedSub) { + BlockKey parent = cachedSub.getParent().getBlockKey(); + blockKey = new GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(), + cachedSub.getGroupIndex()); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); } - else { - final IndexedMatrixValue task = tmp.get(); - OOCIOHandler.SourceBlockDescriptor descriptor = null; - BlockKey blockKey = null; - boolean ownsEntry = true; - - if(tmp instanceof OOCCacheManager.CachedQueueCallback cachedQueue) { - blockKey = cachedQueue.getBlockKey(); - ownsEntry = false; - if(mSubscribers != null && mSubscribers.length > 0) - mCallback = tmp.keepOpen(); - } - else if(tmp instanceof OOCCacheManager.CachedSubCallback cachedSub) { - BlockKey parent = cachedSub.getParent().getBlockKey(); - blockKey = new GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(), - cachedSub.getGroupIndex()); - ownsEntry = false; - if(mSubscribers != null && mSubscribers.length > 0) - mCallback = tmp.keepOpen(); - } - if(_source instanceof SourceOOCStream src) { - descriptor = src.getDescriptor(task.getIndexes()); - } + if(_source instanceof SourceOOCStream src) { + descriptor = src.getDescriptor(task.getIndexes()); + } - if(blockKey == null) { - ownsEntry = true; - blockKey = new BlockKey(_streamId, _numBlocks); - if(descriptor == null) { - if(mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.put(_streamId, _numBlocks, task); - else - mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); - } - else { - if(mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); - else - mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, - descriptor); - } + if(blockKey == null) { + ownsEntry = true; + blockKey = new BlockKey(_streamId, _numBlocks); + if(descriptor == null) { + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.put(_streamId, _numBlocks, task); + else + mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); + } + else { + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); + else + mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, + descriptor); } - if(_index != null) - _index.put(task.getIndexes(), _numBlocks); - blk = _numBlocks; - _numBlocks++; - registerCacheKey(blk, blockKey, ownsEntry); - _consumptionCounts.add(0); - _groupIndices.add(-1); - _groupSizes.add(1); } + if(_index != null) + _index.put(task.getIndexes(), _numBlocks); + + blk = _numBlocks; + _numBlocks++; + registerCacheKey(blk, blockKey, ownsEntry); + _consumptionCounts.add(0); + _groupIndices.add(-1); + _groupSizes.add(1); + + if(_deletable) + tryDeleteBlock(blk); + } + notifyAll(); } else { @@ -308,7 +315,7 @@ public synchronized void scheduleDeletion() { return; // Deletion already scheduled if (_cacheInProgress && _maxConsumptionCount == 0) - throw new DMLRuntimeException("Cannot have a caching stream with no listeners"); + System.out.println("[WARN] Scheduling deletion for caching stream with no listeners: " + this); _deletable = true; for (int i = 0; i < _consumptionCounts.size(); i++) { @@ -326,13 +333,22 @@ private synchronized void tryDeleteBlock(int i) { throw new DMLRuntimeException("Cannot have more than " + _maxConsumptionCount + " consumptions."); if(!_ownedCacheKeys.get(i)) return; + + int groupIdx = _groupIndices.getInt(i); + int groupSize = _groupSizes.getInt(i); + if (groupIdx > 0 && groupSize > 1) { + // Grouped entries are physically represented by a single base cache entry. + // Re-check the base when a member reaches its terminal count. + if (cnt == _maxConsumptionCount) { + int baseId = i - groupIdx; + tryDeleteBlock(baseId); + } + return; + } + if (cnt == _maxConsumptionCount) { - int groupIdx = _groupIndices.getInt(i); - int groupSize = _groupSizes.getInt(i); if (groupIdx >= 0 && groupSize > 1) { int baseId = i - groupIdx; - if (i != baseId) - return; for (int j = 0; j < groupSize; j++) { if (_consumptionCounts.getInt(baseId + j) < _maxConsumptionCount) return; @@ -396,7 +412,7 @@ else if(!_cacheInProgress) { } } - public synchronized int findCachedIndex(MatrixIndexes idx) { + public int findCachedIndex(MatrixIndexes idx) { return _index.get(idx); } @@ -465,10 +481,7 @@ private void validateBlockCountOnClose() { * Finds a cached item asynchronously without counting it as a consumption. */ public void peekCachedAsync(MatrixIndexes idx, Consumer> callback) { - int mIdx; - synchronized(this) { - mIdx = _index.get(idx); - } + int mIdx = _index.get(idx); OOCCacheManager.requestBlock(getBlockKey(mIdx)).whenComplete((cb, r) -> callback.accept(cb)); } @@ -476,10 +489,7 @@ public void peekCachedAsync(MatrixIndexes idx, Consumer peekCached(MatrixIndexes idx) { - int mIdx; - synchronized(this) { - mIdx = _index.get(idx); - } + int mIdx = _index.get(idx); try { return OOCCacheManager.requestBlock(getBlockKey(mIdx)).get(); } catch (InterruptedException | ExecutionException e) { @@ -516,9 +526,13 @@ private void registerCacheKey(int blockIdx, BlockKey key, boolean ownsEntry) { throw new IllegalStateException("Invalid cache key registration order"); } - public synchronized void activateIndexing() { - if (_index == null) - _index = new DualHashBidiMap<>(); + public void activateIndexing() { + if(_index != null) + return; + synchronized(this) { + if(_index == null) + _index = new ConcurrentHashMap<>(); + } } @Override @@ -740,4 +754,8 @@ public synchronized void incrProcessingCount(int i, int count) { if (_deletable) tryDeleteBlock(i); } + + public void incrProcessingCount(MatrixIndexes idx, int count) { + incrProcessingCount(_index.get(idx), count); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java index 175d81d6e06..f9c9cb7d7d1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java @@ -52,6 +52,13 @@ public static IndexingOOCInstruction parseInstruction(IndexingCPInstruction cpIn throw new NotImplementedException(); } } + else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { + if(!cpInst.input1.getDataType().isMatrix()) + throw new NotImplementedException(); + return new MatrixIndexingOOCInstruction(cpInst.input1, cpInst.input2, cpInst.getRowLower(), + cpInst.getRowUpper(), cpInst.getColLower(), cpInst.getColUpper(), cpInst.output, cpInst.getOpcode(), + cpInst.getInstructionString()); + } throw new NotImplementedException(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java index 16338888a54..81f1102811d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MMultOOCInstruction.java @@ -84,7 +84,7 @@ public void processInstruction( ExecutionContext ec ) { int emitAggThreshold = (int)min.getDataCharacteristics().getNumColBlocks(); groupedReduceOOC(intermediateStream, outStream, (left, right) -> { - MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperations(plus, right.getValue()); left.setValue(mb); return left; }, emitAggThreshold); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java index d4851ee2ed1..41f29082ded 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java @@ -80,9 +80,9 @@ public void processInstruction(ExecutionContext ec) { addOutStream(qOut); ec.getMatrixObject(output).setStreamHandle(qOut); - OOCStream qInX = min.getStreamHandle(); - boolean createdCache = !qInX.hasStreamCache(); - CachingStream xCache = createdCache ? new CachingStream(qInX) : qInX.getStreamCache(); + OOCStreamable xStreamable = min.getStreamable(); + boolean createdCache = !xStreamable.hasStreamCache(); + CachingStream xCache = createdCache ? new CachingStream(min.getStreamHandle()) : xStreamable.getStreamCache(); long numRowBlocksL = min.getDataCharacteristics().getNumRowBlocks(); long numColBlocksL = min.getDataCharacteristics().getNumColBlocks(); @@ -144,7 +144,7 @@ public void processInstruction(ExecutionContext ec) { }, tmp -> tmp.getIndexes().getColumnIndex(), tmp -> tmp.getIndexes().getRowIndex()); CompletableFuture reduceXvFuture = groupedReduceOOC(qPartialXv, qXv, (left, right) -> { - MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperations(plus, right.getValue()); left.setValue(mb); return left; }, numColBlocks); @@ -159,7 +159,7 @@ public void processInstruction(ExecutionContext ec) { uFuture = broadcastJoinOOC(qXv, qW, qWeighted, (u, w) -> { MatrixBlock uBlock = (MatrixBlock) u.getValue(); MatrixBlock wBlock = (MatrixBlock) w.getValue().getValue(); - MatrixBlock updated = uBlock.binaryOperationsInPlace(weightOp, wBlock); + MatrixBlock updated = uBlock.binaryOperations(weightOp, wBlock); u.setValue(updated); return u; }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java index 116e65302f1..83d4ba58fd2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java @@ -19,19 +19,29 @@ package org.apache.sysds.runtime.instructions.ooc; -import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.ooc.stream.FilteredOOCStream; +import org.apache.sysds.runtime.ooc.stream.SubOOCStream; import org.apache.sysds.runtime.util.IndexRange; +import org.apache.sysds.runtime.util.UtilFunctions; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -44,6 +54,11 @@ public MatrixIndexingOOCInstruction(CPOperand in, CPOperand rl, CPOperand ru, CP super(in, rl, ru, cl, cu, out, opcode, istr); } + public MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, + CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) { + super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr); + } + @Override public void processInstruction(ExecutionContext ec) { String opcode = getOpcode(); @@ -58,40 +73,64 @@ public void processInstruction(ExecutionContext ec) { boolean inRange = ix.rowStart < mo.getNumRows() && ix.colStart < mo.getNumColumns(); - OOCStream qIn = mo.getStreamHandle(); - OOCStream qOut = createWritableStream(); - - addInStream(qIn); - addOutStream(qOut); - - MatrixObject mOut = ec.getMatrixObject(output); - mOut.setStreamHandle(qOut); - //right indexing if(opcode.equalsIgnoreCase(Opcodes.RIGHT_INDEX.toString())) { - if(output.isScalar() && inRange) { + OOCStream qIn = mo.getStreamHandle(); + addInStream(qIn); + + if(output.isScalar()) { + if(!inRange) + throw new DMLRuntimeException( + "Invalid values for matrix indexing: [" + (ix.rowStart + 1) + ":" + (ix.rowEnd + 1) + "," + + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "] must be within matrix dimensions [" + + mo.getNumRows() + "x" + mo.getNumColumns() + "]."); + + Double scalarOut = null; IndexedMatrixValue tmp; while((tmp = qIn.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { - if(tmp.getIndexes().getRowIndex() == firstBlockRow && - tmp.getIndexes().getColumnIndex() == firstBlockCol) { - ec.setScalarOutput(output.getName(), new DoubleObject( - tmp.getValue().get((int) ix.rowStart % blocksize, (int) ix.rowEnd % blocksize))); - return; + if(tmp.getIndexes().getRowIndex() == firstBlockRow + 1 && + tmp.getIndexes().getColumnIndex() == firstBlockCol + 1) { + scalarOut = ((MatrixBlock) tmp.getValue()).get((int) (ix.rowStart % blocksize), + (int) (ix.colStart % blocksize)); } } + if(scalarOut == null) + throw new DMLRuntimeException("Desired block not found"); + ec.setScalarOutput(output.getName(), new DoubleObject(scalarOut)); + return; + } - throw new DMLRuntimeException("Desired block not found"); + if(ix.rowStart < 0 || ix.rowStart >= mo.getNumRows() || ix.rowEnd < ix.rowStart || + ix.rowEnd >= mo.getNumRows() || ix.colStart < 0 || ix.colStart >= mo.getNumColumns() || + ix.colEnd < ix.colStart || ix.colEnd >= mo.getNumColumns()) { + String dbg = "inst=\"" + instString + "\", input=" + input1.getName() + ", output=" + output.getName() + + ", rowLower=" + debugScalarOperand(rowLower, ec) + ", rowUpper=" + + debugScalarOperand(rowUpper, ec) + ", colLower=" + debugScalarOperand(colLower, ec) + + ", colUpper=" + debugScalarOperand(colUpper, ec) + ", resolvedRange=[" + (ix.rowStart + 1) + ":" + + (ix.rowEnd + 1) + "," + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "]" + ", matrixDims=[" + + mo.getNumRows() + "x" + mo.getNumColumns() + "]" + ", blocksize=" + blocksize; + System.out.println("[WARN] OOC rightIndex bounds violation: " + dbg); + throw new DMLRuntimeException( + "Invalid values for matrix indexing: [" + (ix.rowStart + 1) + ":" + (ix.rowEnd + 1) + "," + + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "] must be within matrix dimensions [" + + mo.getNumRows() + "x" + mo.getNumColumns() + "]. " + dbg); } + MatrixObject mOut = ec.getMatrixObject(output); + ec.getDataCharacteristics(output.getName()).set(ix.rowSpan() + 1, ix.colSpan() + 1, blocksize, -1); + OOCStream qOut = createWritableStream(); + addOutStream(qOut); + mOut.setStreamHandle(qOut); + qIn.setDownstreamMessageRelay(qOut::messageDownstream); qOut.setUpstreamMessageRelay(qIn::messageUpstream); qOut.setIXTransform((downstream, range) -> { - if(downstream){ - long rs = range.rowStart-ix.rowStart+1; - long re = range.rowEnd-ix.rowStart+1; - long cs = range.colStart-ix.colStart+1; - long ce = range.colEnd-ix.colStart+1; + if(downstream) { + long rs = range.rowStart - ix.rowStart + 1; + long re = range.rowEnd - ix.rowStart + 1; + long cs = range.colStart - ix.colStart + 1; + long ce = range.colEnd - ix.colStart + 1; // TODO What happens if range is out of bounds? rs = Math.max(1, rs); cs = Math.max(1, cs); @@ -99,15 +138,31 @@ public void processInstruction(ExecutionContext ec) { ce = Math.min(ix.colSpan(), ce); return new IndexRange(rs, re, cs, ce); } - else{ - long rs = range.rowStart+ix.rowStart; - long re = range.rowEnd+ix.rowStart; - long cs = range.colStart+ix.colStart; - long ce = range.colEnd+ix.colStart; + else { + long rs = range.rowStart + ix.rowStart; + long re = range.rowEnd + ix.rowStart; + long cs = range.colStart + ix.colStart; + long ce = range.colEnd + ix.colStart; return new IndexRange(rs, re, cs, ce); } }); + if(firstBlockRow == lastBlockRow && firstBlockCol == lastBlockCol) { + MatrixIndexes srcBlock = new MatrixIndexes(firstBlockRow + 1, firstBlockCol + 1); + OOCStream filteredStream = new FilteredOOCStream<>(qIn, + tmp -> tmp.getIndexes().equals(srcBlock)); + mapOOC(filteredStream, qOut, tmp -> { + MatrixBlock block = (MatrixBlock) tmp.getValue(); + int rowStartLocal = (int) (ix.rowStart % blocksize); + int rowEndLocal = Math.min(block.getNumRows() - 1, (int) (ix.rowEnd % blocksize)); + int colStartLocal = (int) (ix.colStart % blocksize); + int colEndLocal = Math.min(block.getNumColumns() - 1, (int) (ix.colEnd % blocksize)); + MatrixBlock outBlock = block.slice(rowStartLocal, rowEndLocal, colStartLocal, colEndLocal); + return new IndexedMatrixValue(new MatrixIndexes(1, 1), outBlock); + }); + return; + } + if(ix.rowStart % blocksize == 0 && ix.colStart % blocksize == 0) { // Aligned case: interior blocks can be forwarded directly, borders may require slicing final int outBlockRows = (int) Math.ceil((double) (ix.rowSpan() + 1) / blocksize); @@ -115,16 +170,17 @@ public void processInstruction(ExecutionContext ec) { final int totalBlocks = outBlockRows * outBlockCols; final boolean isCached = qIn.hasStreamCache(); final AtomicInteger producedBlocks = new AtomicInteger(0); - CompletableFuture future = new CompletableFuture<>(); + CompletableFuture future = new CompletableFuture<>(); mapOptionalOOC(qIn, qOut, tmp -> { - if (future.isDone()) + if(future.isDone()) return Optional.empty(); long blockRow = tmp.getIndexes().getRowIndex() - 1; long blockCol = tmp.getIndexes().getColumnIndex() - 1; - boolean within = blockRow >= firstBlockRow && blockRow <= lastBlockRow && - blockCol >= firstBlockCol && blockCol <= lastBlockCol; + boolean within = + blockRow >= firstBlockRow && blockRow <= lastBlockRow && blockCol >= firstBlockCol && + blockCol <= lastBlockCol; if(!within) return Optional.empty(); @@ -173,11 +229,12 @@ public void processInstruction(ExecutionContext ec) { blockCol <= lastBlockCol; if(!pass && !hasIntermediateStream) - qIn.getStreamCache().incrProcessingCount(qIn.getStreamCache().findCachedIndex(tmp.getIndexes()), 1); + qIn.getStreamCache().incrProcessingCount(tmp.getIndexes(), 1); return pass; }); - final CachingStream cachedStream = hasIntermediateStream ? new CachingStream(filteredStream) : qIn.getStreamCache(); + final CachingStream cachedStream = hasIntermediateStream ? new CachingStream( + filteredStream) : qIn.getStreamCache(); cachedStream.activateIndexing(); cachedStream.incrSubscriberCount(1); // We may require re-consumption of blocks (up to 4 times) OOCStream readStream = cachedStream.getReadStream(); @@ -211,7 +268,7 @@ public void processInstruction(ExecutionContext ec) { if(mIdx == null) continue; - try (OOCStream.QueueCallback cb = cachedStream.peekCached(mIdx)) { + try(OOCStream.QueueCallback cb = cachedStream.peekCached(mIdx)) { IndexedMatrixValue mv = cb.get(); MatrixBlock srcBlock = (MatrixBlock) mv.getValue(); @@ -243,16 +300,16 @@ public void processInstruction(ExecutionContext ec) { final int maxConsumptions = aligner.getNumConsumptions(mIdx); Integer con = consumptionCounts.compute(mIdx, (k, v) -> { - if (v == null) + if(v == null) v = 0; - v = v+1; - if (v == maxConsumptions) + v = v + 1; + if(v == maxConsumptions) return null; return v; }); - if (con == null) - cachedStream.incrProcessingCount(cachedStream.findCachedIndex(mIdx), 1); + if(con == null) + cachedStream.incrProcessingCount(mIdx, 1); } } @@ -261,25 +318,251 @@ public void processInstruction(ExecutionContext ec) { if(completed) future.complete(null); - }) - .thenRun(() -> { - aligner.close(); - qOut.closeInput(); - }) - .exceptionally(err -> { - qOut.propagateFailure(DMLRuntimeException.of(err)); - return null; - }); + }).thenRun(() -> { + aligner.close(); + qOut.closeInput(); + }).exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); - if (hasIntermediateStream) + if(hasIntermediateStream) cachedStream.scheduleDeletion(); // We can immediately delete blocks after consumption } - //left indexing else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) { - throw new NotImplementedException(); + MatrixObject mOut = ec.getMatrixObject(output); + ec.getDataCharacteristics(output.getName()).set(mo.getNumRows(), mo.getNumColumns(), blocksize, -1); + if(input2.getDataType().isScalar()) { + if(!ix.isScalar()) + throw new DMLRuntimeException("Invalid index range of scalar leftindexing: " + ix + "."); + if(ix.rowStart < 0 || ix.rowStart >= mo.getNumRows() || ix.colStart < 0 || + ix.colStart >= mo.getNumColumns()) { + throw new DMLRuntimeException( + "Invalid values for matrix indexing: [" + (ix.rowStart + 1) + ":" + (ix.rowEnd + 1) + "," + + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "] must be within matrix dimensions [" + + mo.getNumRows() + "x" + mo.getNumColumns() + "]."); + } + + final ScalarObject scalar = ec.getScalarInput(input2.getName(), ValueType.FP64, input2.isLiteral()); + final double scalarValue = scalar.getDoubleValue(); + final long targetBlockRow = ix.rowStart / blocksize + 1; + final long targetBlockCol = ix.colStart / blocksize + 1; + final int targetLocalRow = (int) (ix.rowStart % blocksize); + final int targetLocalCol = (int) (ix.colStart % blocksize); + + OOCStream qLhs = mo.getStreamHandle(); + OOCStream qOutRaw = createWritableStream(); + SubOOCStream qOut = new SubOOCStream<>(qOutRaw); + addInStream(qLhs); + addOutStream(qOut); + mOut.setStreamHandle(qOut); + + submitOOCTasks(qLhs, cb -> { + IndexedMatrixValue lhs = cb.get(); + MatrixIndexes idx = lhs.getIndexes(); + if(idx.getRowIndex() != targetBlockRow || idx.getColumnIndex() != targetBlockCol) { + qOut.enqueue(cb.keepOpen()); + return; + } + + MatrixBlock src = (MatrixBlock) lhs.getValue(); + MatrixBlock updated = new MatrixBlock(src); + updated.set(targetLocalRow, targetLocalCol, scalarValue); + updated.examSparsity(); + qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(idx), updated)); + }).thenRun(() -> { + qOut.closeInput(); + qOutRaw.closeInput(); + }).exceptionally(err -> { + DMLRuntimeException dmlErr = DMLRuntimeException.of(err); + qOut.propagateFailure(dmlErr); + qOutRaw.propagateFailure(dmlErr); + qOutRaw.closeInput(); + return null; + }); + return; + } + + final MatrixObject rhsMo = ec.getMatrixObject(input2.getName()); + final long lhsRows = mo.getNumRows(); + final long lhsCols = mo.getNumColumns(); + final long rhsRows = rhsMo.getNumRows(); + final long rhsCols = rhsMo.getNumColumns(); + + if(ix.rowSpan() + 1 != rhsRows || ix.colSpan() + 1 != rhsCols) { + throw new DMLRuntimeException( + "Invalid index range of leftindexing: [" + (ix.rowStart + 1) + ":" + (ix.rowEnd + 1) + "," + + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "] vs [" + rhsRows + "x" + rhsCols + "]."); + } + if(ix.rowStart < 0 || ix.rowStart >= lhsRows || ix.rowEnd < ix.rowStart || ix.rowEnd >= lhsRows || + ix.colStart < 0 || ix.colStart >= lhsCols || ix.colEnd < ix.colStart || ix.colEnd >= lhsCols) { + throw new DMLRuntimeException( + "Invalid values for matrix indexing: [" + (ix.rowStart + 1) + ":" + (ix.rowEnd + 1) + "," + + (ix.colStart + 1) + ":" + (ix.colEnd + 1) + "] must be within matrix dimensions [" + lhsRows + + "x" + lhsCols + "]."); + } + + final IndexRange shiftRange = new IndexRange(ix.rowStart + 1, ix.rowEnd + 1, ix.colStart + 1, + ix.colEnd + 1); + final BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + + OOCStream qLhs = mo.getStreamHandle(); + OOCStream qRhs = rhsMo.getStreamHandle(); + OOCStream qOutRaw = createWritableStream(); + SubOOCStream qOut = new SubOOCStream<>(qOutRaw); + + addInStream(qLhs, qRhs); + addOutStream(qOut); + mOut.setStreamHandle(qOut); + + final Map aggregators = new ConcurrentHashMap<>(); + submitOOCTasks(List.of(qLhs, qRhs), (streamIdx, cb) -> { + if(streamIdx == 0) { + IndexedMatrixValue lhs = cb.get(); + MatrixIndexes lhsIx = lhs.getIndexes(); + if(!UtilFunctions.isInBlockRange(lhsIx, blocksize, shiftRange)) { + qOut.enqueue(cb.keepOpen()); + return; + } + + MatrixIndexes key = new MatrixIndexes(lhsIx); + int expectedRhsContribs = getExpectedRhsContribs(key, shiftRange, blocksize, lhsRows, lhsCols); + LeftIndexAccumulator acc = aggregators.computeIfAbsent(key, + k -> new LeftIndexAccumulator(expectedRhsContribs)); + + IndexRange zeroRange = UtilFunctions.getSelectedRangeForZeroOut(lhs, blocksize, shiftRange); + MatrixBlock lhsZeroed = ((MatrixBlock) lhs.getValue()).zeroOutOperations(new MatrixBlock(), + zeroRange); + + MatrixBlock out = acc.addLhs(lhsZeroed, plus); + if(out != null) { + if(!aggregators.remove(key, acc)) + throw new DMLRuntimeException( + "Failed to remove completed LEFT_INDEX accumulator for " + key); + out.examSparsity(); + qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(key), out)); + } + } + else { + IndexedMatrixValue rhs = cb.get(); + ArrayList shifted = new ArrayList<>(); + OperationsOnMatrixValues.performShift(rhs, shiftRange, blocksize, lhsRows, lhsCols, shifted); + + for(IndexedMatrixValue part : shifted) { + MatrixIndexes key = new MatrixIndexes(part.getIndexes()); + LeftIndexAccumulator acc = aggregators.computeIfAbsent(key, k -> new LeftIndexAccumulator( + getExpectedRhsContribs(k, shiftRange, blocksize, lhsRows, lhsCols))); + + MatrixBlock out = acc.addRhs((MatrixBlock) part.getValue(), plus); + if(out != null) { + if(!aggregators.remove(key, acc)) + throw new DMLRuntimeException( + "Failed to remove completed LEFT_INDEX accumulator for " + key); + out.examSparsity(); + qOut.enqueue(new IndexedMatrixValue(new MatrixIndexes(key), out)); + } + } + } + }).thenRun(() -> { + if(!aggregators.isEmpty()) + throw new DMLRuntimeException( + "LEFT_INDEX finished with unfinished aggregators: " + aggregators.size()); + qOut.closeInput(); + qOutRaw.closeInput(); + }).exceptionally(err -> { + DMLRuntimeException dmlErr = DMLRuntimeException.of(err); + qOut.propagateFailure(dmlErr); + qOutRaw.propagateFailure(dmlErr); + qOutRaw.closeInput(); + return null; + }); } else throw new DMLRuntimeException( "Invalid opcode (" + opcode + ") encountered in MatrixIndexingOOCInstruction."); } + + private static String debugScalarOperand(CPOperand op, ExecutionContext ec) { + try { + return op.getName() + "=" + ec.getScalarInput(op).getStringValue() + (op.isLiteral() ? " [lit]" : " [var]"); + } + catch(Exception ex) { + return op.getName() + "="; + } + } + + private static int getExpectedRhsContribs(MatrixIndexes lhsIx, IndexRange shift, int bs, long lhsRows, + long lhsCols) { + + long lrs = UtilFunctions.computeCellIndex(lhsIx.getRowIndex(), bs, 0); + long lcs = UtilFunctions.computeCellIndex(lhsIx.getColumnIndex(), bs, 0); + long lre = lrs + UtilFunctions.computeBlockSize(lhsRows, lhsIx.getRowIndex(), bs) - 1; + long lce = lcs + UtilFunctions.computeBlockSize(lhsCols, lhsIx.getColumnIndex(), bs) - 1; + + long ors = Math.max(lrs, shift.rowStart), ore = Math.min(lre, shift.rowEnd); + long ocs = Math.max(lcs, shift.colStart), oce = Math.min(lce, shift.colEnd); + if(ors > ore || ocs > oce) + return 0; + + long rhsRowStart = ors - shift.rowStart + 1; + long rhsColStart = ocs - shift.colStart + 1; + long rowLen = ore - ors + 1; + long colLen = oce - ocs + 1; + + long rBlocks = UtilFunctions.computeBlockIndex(rhsRowStart + rowLen - 1, bs) - + UtilFunctions.computeBlockIndex(rhsRowStart, bs) + 1; + long cBlocks = UtilFunctions.computeBlockIndex(rhsColStart + colLen - 1, bs) - + UtilFunctions.computeBlockIndex(rhsColStart, bs) + 1; + + return Math.toIntExact(rBlocks * cBlocks); + } + + private static class LeftIndexAccumulator { + private final int _expectedRhsContribs; + private MatrixBlock _lhs; + private MatrixBlock _rhsAgg; + private int _rhsCtr; + private boolean _lhsSeen; + private boolean _emitted; + + private LeftIndexAccumulator(int expectedRhsContribs) { + _expectedRhsContribs = expectedRhsContribs; + _rhsCtr = 0; + _lhsSeen = false; + _emitted = false; + } + + public synchronized MatrixBlock addLhs(MatrixBlock lhs, BinaryOperator plus) { + if(_lhsSeen) + throw new DMLRuntimeException("Duplicate LEFT_INDEX lhs contribution encountered"); + _lhs = lhs; + _lhsSeen = true; + return emitIfReady(plus); + } + + public synchronized MatrixBlock addRhs(MatrixBlock rhs, BinaryOperator plus) { + if(_emitted) + throw new DMLRuntimeException("LEFT_INDEX accumulator received rhs after completion"); + _rhsCtr++; + if(_rhsCtr > _expectedRhsContribs) + throw new DMLRuntimeException( + "LEFT_INDEX accumulator rhs overflow: " + _rhsCtr + " > " + _expectedRhsContribs); + if(_rhsAgg == null) + _rhsAgg = rhs; + else + _rhsAgg = _rhsAgg.binaryOperationsInPlace(plus, rhs); + return emitIfReady(plus); + } + + private MatrixBlock emitIfReady(BinaryOperator plus) { + if(_emitted || !_lhsSeen || _rhsCtr < _expectedRhsContribs) + return null; + if(_rhsCtr > _expectedRhsContribs) + throw new DMLRuntimeException("LEFT_INDEX accumulator encountered invalid rhs contribution count"); + _emitted = true; + if(_rhsAgg != null) + _lhs = _lhs.binaryOperationsInPlace(plus, _rhsAgg); + return _lhs; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index f7cefe635df..8f1f6fef24d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -26,6 +26,8 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.OOCInstructionParser; +import org.apache.sysds.runtime.instructions.cp.CPInstruction; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; @@ -46,18 +48,21 @@ import scala.Tuple5; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.TreeMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.ForkJoinTask; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.LongAdder; import java.util.function.BiConsumer; import java.util.function.BiFunction; @@ -130,10 +135,19 @@ public String getGraphString() { @Override public Instruction preprocessInstruction(ExecutionContext ec) { - if (DMLScript.OOC_LOG_EVENTS) - nanoTime = System.nanoTime(); - // TODO - return super.preprocessInstruction(ec); + Instruction tmp = super.preprocessInstruction(ec); + + // Keep consistent with CP/SP instruction patching semantics. + if(tmp.requiresLabelUpdate()) { + String updInst = CPInstruction.updateLabels(tmp.toString(), ec.getVariables()); + tmp = OOCInstructionParser.parseSingleInstruction(updInst); + } + + // Record event start timestamp on the instruction instance that will execute. + if (DMLScript.OOC_LOG_EVENTS && tmp instanceof OOCInstruction) + ((OOCInstruction) tmp).nanoTime = System.nanoTime(); + + return tmp; } @Override @@ -191,6 +205,7 @@ protected OOCStream mergeOOCStreams(List> streams) { return new MergedOOCStream<>(streams); } + @SuppressWarnings({"varargs", "unchecked"}) protected OOCStream mergeOOCStreams(OOCStream... streams) { return new MergedOOCStream<>(streams); } @@ -208,6 +223,43 @@ protected CompletableFuture mapOOC(OOCStream qIn, OOCStream q return mapOptionalOOC(qIn, qOut, tmp -> Optional.of(mapper.apply(tmp))); } + protected CompletableFuture expandOOC(OOCStream qIn, OOCStream qOut, Function> op) { + addInStream(qIn); + addOutStream(qOut); + + AtomicInteger deferredCtr = new AtomicInteger(1); + CompletableFuture future = new CompletableFuture<>(); + + submitOOCTasks(qIn, tmp -> { + Collection out; + try(tmp) { + out = op.apply(tmp.get()); + } + if(!out.isEmpty()) { + deferredCtr.getAndIncrement(); + TaskContext.defer(() -> { + out.forEach(qOut::enqueue); + if(deferredCtr.decrementAndGet() == 0) + future.complete(null); + }); + } + }) + .thenRun(() -> { + if(deferredCtr.decrementAndGet() == 0) + future.complete(null); + }) + .exceptionally(err -> { + future.completeExceptionally(err); + return null; + }); + + return future.thenRun(qOut::closeInput).exceptionally(err -> { + DMLRuntimeException dmlErr = DMLRuntimeException.of(err); + qOut.propagateFailure(dmlErr); + throw dmlErr; + }); + } + protected CompletableFuture mapOptionalOOC(OOCStream qIn, OOCStream qOut, Function> optionalMapper) { addInStream(qIn); addOutStream(qOut); @@ -358,12 +410,12 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture broadcastJoinOOC(OOCStream { availableBroadcastInput.forEach((k, v) -> { - rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1); + rightCache.incrProcessingCount(v.idx, 1); }); availableBroadcastInput.clear(); qOut.closeInput(); @@ -478,10 +530,10 @@ protected CompletableFuture joinManyOOC(OOCStream CompletableFuture pipeOOC(List> queues, BiConsu }, (i, tmp) -> {}); } + protected static final class ScanStep { + private final R _out; + private final C _carry; + + protected ScanStep(R out, C carry) { + _out = out; + _carry = carry; + } + + protected R getOut() { + return _out; + } + + protected C getCarry() { + return _carry; + } + } + + /** + * Performs a sequential scan in a given order. This method is currently fully in-memory and thus should not yet + * be used for large sequences. TODO + * @param qIn the (unsorted) input stream + * @param qOut the target output stream + * @param seqFn a sequence function to define a unique order + * @param scanner the scanner function that takes the current sequence item and a carry object and returns output and next carry + * @param sequenceSize the sequence size + * @return + * @param the return tile + * @param the carry state + */ + protected CompletableFuture scanOOC(OOCStream qIn, OOCStream qOut, + Function seqFn, + BiFunction> scanner, long sequenceSize) { + if(sequenceSize <= 0) + throw new DMLRuntimeException("Invalid ordered scan size: " + sequenceSize); + + addOutStream(qOut); + TreeMap> pending = new TreeMap<>(); + long[] expected = new long[] {1L}; + Object[] carry = new Object[1]; + AtomicLong processed = new AtomicLong(0); + + CompletableFuture future = pipeOOC(qIn, cb -> { + List ready = new ArrayList<>(); + List> processedCallbacks = new ArrayList<>(); + try { + synchronized(pending) { + OOCStream.QueueCallback pinned = cb.keepOpen(); + IndexedMatrixValue item = pinned.get(); + long seqIx = seqFn.apply(item); + if(seqIx < 1 || seqIx > sequenceSize) { + pinned.close(); + throw new DMLRuntimeException("Ordered scan index out of bounds: " + seqIx + "/" + sequenceSize); + } + + OOCStream.QueueCallback prior = pending.put(seqIx, pinned); + if(prior != null) { + pending.put(seqIx, prior); + pinned.close(); + throw new DMLRuntimeException("Duplicate ordered scan item for sequence index " + seqIx + "."); + } + + while(pending.containsKey(expected[0])) { + OOCStream.QueueCallback curCb = pending.remove(expected[0]); + try { + IndexedMatrixValue cur = curCb.get(); + @SuppressWarnings("unchecked") + C agg = (C) carry[0]; + ScanStep step = scanner.apply(cur, agg); + if(step == null) + throw new DMLRuntimeException("Ordered scan step must not be null."); + R out = step.getOut(); + if(out == null) + throw new DMLRuntimeException("Ordered scan output must not be null."); + carry[0] = step.getCarry(); + ready.add(out); + expected[0]++; + processed.incrementAndGet(); + } + finally { + processedCallbacks.add(curCb); + } + } + } + + for(R out : ready) + qOut.enqueue(out); + } + finally { + processedCallbacks.forEach(OOCStream.QueueCallback::close); + } + }); + + return future.handle((r, err) -> { + synchronized(pending) { + if(err != null) { + pending.values().forEach(OOCStream.QueueCallback::close); + pending.clear(); + throw DMLRuntimeException.of(err); + } + + if(processed.get() != sequenceSize || !pending.isEmpty()) { + Object pendingKeys = pending.keySet().toString(); + pending.values().forEach(OOCStream.QueueCallback::close); + pending.clear(); + throw new DMLRuntimeException( + "Incomplete ordered scan processing. processed=" + processed.get() + ", expected=" + + sequenceSize + ", pendingKeys=" + pendingKeys); + } + + if(expected[0] != sequenceSize + 1) { + pending.values().forEach(OOCStream.QueueCallback::close); + pending.clear(); + throw new DMLRuntimeException("Missing ordered scan items: expected terminal index=" + + (sequenceSize + 1) + ", seen=" + expected[0]); + } + } + qOut.closeInput(); + return null; + }); + } + + protected CompletableFuture scanOOC(OOCStream qIn, OOCStream qOut, + Function seqFn, BiFunction scanner, + Function carryFn, long sequenceSize) { + return scanOOC(qIn, qOut, seqFn, (IndexedMatrixValue item, C carry) -> { + R out = scanner.apply(item, carry); + if(out == null) + throw new DMLRuntimeException("Ordered scan output must not be null."); + return new ScanStep<>(out, carryFn.apply(out)); + }, sequenceSize); + } + protected CompletableFuture submitOOCTasks(final List> queues, BiConsumer> consumer) { return submitOOCTasks(queues, consumer, null, null); } @@ -878,7 +1063,7 @@ protected CompletableFuture submitOOCTasks(final List> qu COMPUTE_IN_FLIGHT.incrementAndGet(); try { Runnable oocTask = oocTask(() -> { - long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + long taskStartTime = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; try(pinned) { consumer.accept(k, pinned); @@ -894,9 +1079,9 @@ protected CompletableFuture submitOOCTasks(final List> qu Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); _localStatisticsAdder.reset(); } - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } }, localFuture, streamContext); COMPUTE_EXECUTOR.submit(oocTask); @@ -929,7 +1114,7 @@ protected CompletableFuture submitOOCTasks(final List> qu COMPUTE_IN_FLIGHT.incrementAndGet(); try { Runnable oocTask = oocTask(() -> { - long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + long taskStartTime = DMLScript.STATISTICS || DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; try(pinnedGroup) { for(int idx = 0; idx < pinnedGroup.size(); idx++) { OOCStream.QueueCallback sub = pinnedGroup.getCallback(idx); @@ -950,9 +1135,9 @@ protected CompletableFuture submitOOCTasks(final List> qu Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); _localStatisticsAdder.reset(); } - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } }, localFuture, streamContext); COMPUTE_EXECUTOR.submit(oocTask); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java index 69e669a40b7..289ef8e6b87 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java @@ -19,6 +19,14 @@ package org.apache.sysds.runtime.instructions.ooc; +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; +import org.apache.sysds.runtime.ooc.cache.OOCCacheScheduler; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; @@ -31,20 +39,25 @@ */ public final class OOCWatchdog { public static final boolean WATCH = false; + private static final double PINNED_NEAR_LIMIT_RATIO = 0.9; + private static final int TOP_PINNED_STREAMS = 5; private static final ConcurrentHashMap OPEN = new ConcurrentHashMap<>(); - private static final ScheduledExecutorService EXEC = - Executors.newSingleThreadScheduledExecutor(r -> { - Thread t = new Thread(r, "TemporaryWatchdog"); - t.setDaemon(true); - return t; - }); + private static final ScheduledExecutorService EXEC = Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "TemporaryWatchdog"); + t.setDaemon(true); + return t; + }); private static final long STALE_MS = TimeUnit.SECONDS.toMillis(10); private static final long SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(10); + private static final long CACHE_SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(1); static { - if (WATCH) + if(WATCH) { EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS); + EXEC.scheduleAtFixedRate(OOCWatchdog::scanCachePressure, CACHE_SCAN_INTERVAL_MS, CACHE_SCAN_INTERVAL_MS, + TimeUnit.MILLISECONDS); + } } private OOCWatchdog() { @@ -57,7 +70,7 @@ public static void registerOpen(String id, String desc, String context, OOCStrea public static void addEvent(String id, String eventMsg) { Entry e = OPEN.get(id); - if (e != null) + if(e != null) e.events.add(eventMsg); } @@ -67,17 +80,74 @@ public static void registerClose(String id) { private static void scan() { long now = System.currentTimeMillis(); - for (Map.Entry e : OPEN.entrySet()) { - if (now - e.getValue().openedAt >= STALE_MS) { - if (e.getValue().events.isEmpty() && !(e.getValue().stream instanceof CachingStream)) + for(Map.Entry e : OPEN.entrySet()) { + if(now - e.getValue().openedAt >= STALE_MS) { + if(e.getValue().events.isEmpty() && !(e.getValue().stream instanceof CachingStream)) continue; // Probably just a stream that has no consumer (remains to be checked why this can happen) - System.err.println("[TemporaryWatchdog] Still open after " + (now - e.getValue().openedAt) + "ms: " - + e.getKey() + " (" + e.getValue().desc + ")" - + (e.getValue().context != null ? " ctx=" + e.getValue().context : "")); + System.err.println( + "[TemporaryWatchdog] Still open after " + (now - e.getValue().openedAt) + "ms: " + e.getKey() + + " (" + e.getValue().desc + ")" + + (e.getValue().context != null ? " ctx=" + e.getValue().context : "")); } } } + private static void scanCachePressure() { + OOCCacheScheduler cache = OOCCacheManager.getCacheIfInitialized(); + if(cache == null) + return; + + long hardLimit = cache.getHardLimit(); + if(hardLimit <= 0) + return; + long pinnedBytes = cache.getPinnedBytes(); + if(pinnedBytes < (long) (hardLimit * PINNED_NEAR_LIMIT_RATIO)) + return; + + Collection snapshot = cache.snapshot(); + if(snapshot.isEmpty()) + return; + + HashMap pinnedByStream = new HashMap<>(); + long pinnedBlocks = 0; + for(BlockEntry entry : snapshot) { + if(!entry.isPinned()) + continue; + long streamId = entry.getKey().getStreamId(); + StreamPinStats stats = pinnedByStream.computeIfAbsent(streamId, sid -> new StreamPinStats()); + stats.bytes += entry.getSize(); + stats.blocks++; + pinnedBlocks++; + } + + if(pinnedByStream.isEmpty()) + return; + + ArrayList> top = new ArrayList<>(pinnedByStream.entrySet()); + top.sort(Comparator.comparingLong((Map.Entry e) -> e.getValue().bytes).reversed()); + + StringBuilder sb = new StringBuilder(); + sb.append("[WARN] OOCWatchdog: pinned memory near hard limit: "); + sb.append(toMiB(pinnedBytes)).append("MiB / ").append(toMiB(hardLimit)).append("MiB (") + .append(String.format("%.1f", 100.0 * pinnedBytes / hardLimit)).append("%)"); + sb.append(", pinned blocks=").append(pinnedBlocks).append(", top streams=["); + + int n = Math.min(TOP_PINNED_STREAMS, top.size()); + for(int i = 0; i < n; i++) { + Map.Entry e = top.get(i); + if(i > 0) + sb.append("; "); + sb.append(e.getKey()).append(": ").append(toMiB(e.getValue().bytes)).append("MiB (") + .append(e.getValue().blocks).append(" blocks)"); + } + sb.append("]"); + System.err.println(sb); + } + + private static long toMiB(long bytes) { + return bytes / (1024 * 1024); + } + private static class Entry { final String desc; final String context; @@ -93,4 +163,9 @@ private static class Entry { this.events = new ConcurrentLinkedQueue<>(); } } + + private static class StreamPinStats { + long bytes; + long blocks; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java index b1d397d919a..87b17f77192 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java @@ -22,6 +22,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; +import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -35,10 +36,13 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.matrix.operators.SimpleOperator; +import org.apache.sysds.runtime.util.UtilFunctions; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -73,6 +77,10 @@ public static ParameterizedBuiltinOOCInstruction parseInstruction(String str) { else if(opcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) { return new ParameterizedBuiltinOOCInstruction(null, paramsMap, out, opcode, str); } + else if(opcode.equalsIgnoreCase(Opcodes.REXPAND.toString())) { + func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode); + return new ParameterizedBuiltinOOCInstruction(new SimpleOperator(func), paramsMap, out, opcode, str); + } else throw new NotImplementedException(); // TODO } @@ -132,6 +140,32 @@ else if(instOpcode.equalsIgnoreCase(Opcodes.CONTAINS.toString())) { } ec.setScalarOutput(output.getName(), new BooleanObject(ret)); + } + else if(instOpcode.equalsIgnoreCase(Opcodes.REXPAND.toString())) { + MatrixObject targetObj = ec.getMatrixObject(params.get("target")); + OOCStream qIn = targetObj.getStreamHandle(); + OOCStream qOut = createWritableStream(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + String maxValName = params.get("max"); + long lmaxVal = maxValName.startsWith(Lop.SCALAR_VAR_NAME_PREFIX) ? + ec.getScalarInput(maxValName, Types.ValueType.FP64, false).getLongValue() : + UtilFunctions.toLong(Double.parseDouble(maxValName)); + boolean dirRows = params.get("dir").equals("rows"); + boolean cast = Boolean.parseBoolean(params.get("cast")); + boolean ignore = Boolean.parseBoolean(params.get("ignore")); + long blen = targetObj.getBlocksize(); + + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + + expandOOC(qIn, qOut, tmp -> { + ArrayList out = new ArrayList<>(); + LibMatrixReorg.rexpand(tmp, lmaxVal, dirRows, cast, ignore, blen, out); + return out; + }); + } + else + throw new NotImplementedException(); } } -} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java index e861f7afc57..db61aeaddae 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReorgOOCInstruction.java @@ -41,17 +41,31 @@ public class ReorgOOCInstruction extends ComputationOOCInstruction { private final CPOperand _col; private final CPOperand _desc; private final CPOperand _ixret; + // reshape-specific attributes + private final CPOperand _opRows; + private final CPOperand _opCols; + private final CPOperand _opDims; + private final CPOperand _opByRow; protected ReorgOOCInstruction(ReorgOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { - this(op, in1, out, null, null, null, opcode, istr); + this(op, in1, out, null, null, null, null, null, null, null, opcode, istr); + } + + private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand opRows, CPOperand opCols, + CPOperand opDims, CPOperand opByRow, String opcode, String istr) { + this(op, in, out, null, null, null, opRows, opCols, opDims, opByRow, opcode, istr); } private ReorgOOCInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret, - String opcode, String istr) { + CPOperand opRows, CPOperand opCols, CPOperand opDims, CPOperand opByRow, String opcode, String istr) { super(OOCType.Reorg, op, in, out, opcode, istr); _col = col; _desc = desc; _ixret = ixret; + _opRows = opRows; + _opCols = opCols; + _opDims = opDims; + _opByRow = opByRow; } public static ReorgOOCInstruction parseInstruction(String str) { @@ -61,7 +75,7 @@ public static ReorgOOCInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; - if (opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) { + if(opcode.equalsIgnoreCase(Opcodes.TRANSPOSE.toString())) { InstructionUtils.checkNumFields(str, 2, 3); in.split(parts[1]); out.split(parts[2]); @@ -69,26 +83,46 @@ public static ReorgOOCInstruction parseInstruction(String str) { ReorgOperator reorg = new ReorgOperator(SwapIndex.getSwapIndexFnObject()); return new ReorgOOCInstruction(reorg, in, out, opcode, str); } - else if (opcode.equalsIgnoreCase(Opcodes.SORT.toString())) { - InstructionUtils.checkNumFields(str, 5,6); + else if(opcode.equalsIgnoreCase(Opcodes.SORT.toString())) { + InstructionUtils.checkNumFields(str, 5, 6); in.split(parts[1]); out.split(parts[5]); CPOperand col = new CPOperand(parts[2]); CPOperand desc = new CPOperand(parts[3]); CPOperand ixret = new CPOperand(parts[4]); int k = Integer.parseInt(parts[6]); - return new ReorgOOCInstruction(new ReorgOperator(new SortIndex(1,false,false), k), - in, out, col, desc, ixret, opcode, str); + return new ReorgOOCInstruction(new ReorgOperator(new SortIndex(1, false, false), k), + in, out, col, desc, ixret, null, null, null, null, opcode, str); + } + else if(opcode.equalsIgnoreCase(Opcodes.RESHAPE.toString())) { + InstructionUtils.checkNumFields(parts, 6); + in.split(parts[1]); + CPOperand rows = new CPOperand(parts[2]); + CPOperand cols = new CPOperand(parts[3]); + CPOperand dims = new CPOperand(parts[4]); + CPOperand byRow = new CPOperand(parts[5]); + out.split(parts[6]); + return new ReorgOOCInstruction(new Operator(true), in, out, rows, cols, dims, byRow, opcode, str); } else throw new NotImplementedException(); } public void processInstruction( ExecutionContext ec ) { - // Create thread and process the transpose operation - MatrixObject min = ec.getMatrixObject(input1); - + if(getOpcode().equalsIgnoreCase(Opcodes.RESHAPE.toString())) { + // TODO Make reshape truly out-of-core + int rows = (int) ec.getScalarInput(_opRows).getLongValue(); + int cols = (int) ec.getScalarInput(_opCols).getLongValue(); + boolean byRow = ec.getScalarInput(_opByRow).getBooleanValue(); + MatrixBlock in = ec.getMatrixInput(input1.getName()); + MatrixBlock out = in.reshape(rows, cols, byRow); + ec.releaseMatrixInput(input1.getName()); + ec.setMatrixOutput(output.getName(), out); + return; + } + // Create thread and process the transpose/sort operation + MatrixObject min = ec.getMatrixObject(input1); ReorgOperator r_op = (ReorgOperator) _optr; if(r_op.fn instanceof SortIndex) { @@ -107,7 +141,7 @@ public void processInstruction( ExecutionContext ec ) { ec.releaseMatrixInput(_col.getName()); ec.releaseMatrixInput(input1.getName()); ec.setMatrixOutput(output.getName(), soresBlock); - } else if (r_op.fn instanceof SwapIndex) { + } else if(r_op.fn instanceof SwapIndex) { OOCStream qIn = min.getStreamHandle(); OOCStream qOut = createWritableStream(); ec.getMatrixObject(output).setStreamHandle(qOut); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index 493aba06c72..548e80df942 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -34,6 +34,14 @@ public class TeeOOCInstruction extends ComputationOOCInstruction { public static void reset() { if (!refCtr.isEmpty()) { System.err.println("There are some dangling streams still in the cache: " + refCtr); + for(CachingStream cache : refCtr.keySet()) { + try { + cache.scheduleDeletion(); + } + catch(Exception ex) { + System.err.println("Failed to schedule deletion for dangling stream " + cache + ": " + ex.getMessage()); + } + } refCtr.clear(); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index b9c7612bfe9..a95cf0ec333 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -19,13 +19,22 @@ package org.apache.sysds.runtime.instructions.ooc; +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.LibMatrixAgg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import org.apache.sysds.runtime.meta.DataCharacteristics; public class UnaryOOCInstruction extends ComputationOOCInstruction { private UnaryOperator _uop = null; @@ -57,17 +66,113 @@ public void processInstruction( ExecutionContext ec ) { UnaryOperator uop = (UnaryOperator) _uop; // Create thread and process the unary operation MatrixObject min = ec.getMatrixObject(input1); + boolean cumSumProd = Builtin.isBuiltinCode(uop.fn, BuiltinCode.CUMSUMPROD); + ec.getDataCharacteristics(output.getName()).set(min.getNumRows(), cumSumProd ? 1 : min.getNumColumns(), + min.getBlocksize(), -1); OOCStream qIn = min.getStreamHandle(); - OOCStream qOut = createWritableStream(); + OOCStream qOut; + boolean cumulative = isCumulativeUnary(uop); + + if(cumulative) { + qOut = processCumulativeUnaryInstruction(ec, uop, qIn); + } + else { + qOut = createWritableStream(); + mapOOC(qIn, qOut, tmp -> { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), tmp.getValue().unaryOperations(uop, new MatrixBlock())); + return tmpOut; + }); + } + ec.getMatrixObject(output).setStreamHandle(qOut); - qIn.setDownstreamMessageRelay(qOut::messageDownstream); - qOut.setUpstreamMessageRelay(qIn::messageUpstream); - - mapOOC(qIn, qOut, tmp -> { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().unaryOperations(uop, new MatrixBlock())); - return tmpOut; - }); + if(!cumulative) { + qIn.setDownstreamMessageRelay(qOut::messageDownstream); + qOut.setUpstreamMessageRelay(qIn::messageUpstream); + } + } + + private OOCStream processCumulativeUnaryInstruction(ExecutionContext ec, UnaryOperator uop, + OOCStream qIn) { + DataCharacteristics dc = ec.getDataCharacteristics(input1.getName()); + if(!dc.dimsKnown()) + throw new DMLRuntimeException( + "OOC cumulative unary operations require known dimensions for deterministic block ordering."); + + BuiltinCode bcode = ((Builtin)uop.fn).getBuiltinCode(); + long rowBlocks = dc.getNumRowBlocks(); + long colBlocks = dc.getNumColBlocks(); + boolean rowCum = (bcode == BuiltinCode.ROWCUMSUM); + boolean sumProd = (bcode == BuiltinCode.CUMSUMPROD); + if(sumProd && colBlocks != 1) + throw new DMLRuntimeException( + "Unsupported OOC cumulative sum-product with more than one column block: " + colBlocks); + + long outerSize = rowCum ? rowBlocks : colBlocks; + long innerSize = rowCum ? colBlocks : rowBlocks; + if(outerSize > Integer.MAX_VALUE) + throw new DMLRuntimeException( + "Unsupported number of cumulative partitions: " + outerSize + " (max " + Integer.MAX_VALUE + ")."); + + int partitions = Math.toIntExact(outerSize); + List> splitInputs = splitOOCStream(qIn, imv -> { + long outerIx = rowCum ? imv.getIndexes().getRowIndex() : imv.getIndexes().getColumnIndex(); + return (int) (outerIx - 1); + }, partitions); + + List> splitOutputs = new ArrayList<>(partitions); + + for(int i = 0; i < partitions; i++) { + OOCStream partOut = createWritableStream(); + splitOutputs.add(partOut); + + this.scanOOC(splitInputs.get(i), partOut, + imv -> rowCum ? imv.getIndexes().getColumnIndex() : imv.getIndexes().getRowIndex(), (imv, agg) -> { + MatrixBlock inBlk = (MatrixBlock) imv.getValue(); + int outRows = inBlk.getNumRows(); + int outCols = sumProd ? 1 : inBlk.getNumColumns(); + MatrixBlock outBlk = LibMatrixAgg.cumaggregateUnaryMatrix(inBlk, + new MatrixBlock(outRows, outCols, false), uop, agg); + MatrixIndexes idx = imv.getIndexes(); + IndexedMatrixValue out = new IndexedMatrixValue(new MatrixIndexes(idx.getRowIndex(), idx.getColumnIndex()), + outBlk); + double[] nextCarry = rowCum ? extractLastColumn(outBlk) : extractLastRow(outBlk); + return new ScanStep<>(out, nextCarry); + }, innerSize).exceptionally(t -> { + partOut.propagateFailure(DMLRuntimeException.of(t)); + return null; + }); + } + + return mergeOOCStreams(splitOutputs); + } + + private static double[] extractLastRow(MatrixBlock blk) { + int rows = blk.getNumRows(); + int cols = blk.getNumColumns(); + double[] ret = new double[cols]; + if(rows == 0 || cols == 0) + return ret; + int lr = rows - 1; + for(int j = 0; j < cols; j++) + ret[j] = blk.get(lr, j); + return ret; + } + + private static double[] extractLastColumn(MatrixBlock blk) { + int rows = blk.getNumRows(); + int cols = blk.getNumColumns(); + double[] ret = new double[rows]; + if(rows == 0 || cols == 0) + return ret; + int lc = cols - 1; + for(int i = 0; i < rows; i++) + ret[i] = blk.get(i, lc); + return ret; + } + + private static boolean isCumulativeUnary(UnaryOperator uop) { + return Builtin.isBuiltinCode(uop.fn, BuiltinCode.CUMSUM, BuiltinCode.ROWCUMSUM, BuiltinCode.CUMPROD, + BuiltinCode.CUMMIN, BuiltinCode.CUMMAX, BuiltinCode.CUMSUMPROD); } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java index 901e043d985..fbc7d642238 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java @@ -137,6 +137,17 @@ synchronized int pin() { return _pinCount; } + /** + * Tries to increment pin-count if already pinned. Unpinned entries are not affected + * by this operation. This allows bypassing the global cache lock. + */ + synchronized boolean fastPin() { + if(_pinCount == 0) + return false; + _pinCount++; + return true; + } + /** * Unpins the underlying data * @return true if the data is now unpinned @@ -148,6 +159,17 @@ synchronized boolean unpin() { return _pinCount == 0; } + /** + * Tries to unpin but guarantees that it will not + * remove the last pin. This allows bypassing the global cache lock. + */ + synchronized boolean fastUnpin() { + if(_pinCount <= 1) + return false; + _pinCount--; + return true; + } + public String toString() { return "Entry" + _key.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index 26e8f010341..a0c8bb075a8 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -23,6 +23,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; @@ -55,6 +56,7 @@ public class OOCCacheManager { } public static void reset() { + TeeOOCInstruction.reset(); OOCIOHandler ioHandler = _ioHandler.getAndSet(null); OOCCacheScheduler cacheScheduler = _scheduler.getAndSet(null); if (ioHandler != null) @@ -94,7 +96,7 @@ public static OOCCacheScheduler getCache() { return scheduler; OOCIOHandler ioHandler = new OOCMatrixIOHandler(); - scheduler = new OOCLRUCacheScheduler(ioHandler, _evictionLimit, _hardLimit); + scheduler = new OOCLRUCacheScheduler(ioHandler, _evictionLimit, _hardLimit, Math.max(40000000, (long)((_hardLimit - _evictionLimit) * 0.1))); if(_scheduler.compareAndSet(null, scheduler)) { _ioHandler.set(ioHandler); @@ -103,6 +105,14 @@ public static OOCCacheScheduler getCache() { } } + /** + * Returns the current cache scheduler if already initialized, otherwise null. + * This method does not trigger lazy initialization. + */ + public static OOCCacheScheduler getCacheIfInitialized() { + return _scheduler.get(); + } + public static OOCIOHandler getIOHandler() { OOCIOHandler io = _ioHandler.get(); if(io != null) diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java index bafda48f4d4..9cc108db5e6 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.ooc.cache; +import java.util.Collection; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -125,6 +126,16 @@ BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, */ long getCacheSize(); + /** + * Returns the number of pinned bytes in the cache. + */ + long getPinnedBytes(); + + /** + * Returns the hard cache limit in bytes. + */ + long getHardLimit(); + /** * Returns if the current cache size is within its defined memory limits. */ @@ -144,4 +155,10 @@ BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size, * Updates the cache limits. */ void updateLimits(long evictionLimit, long hardLimit); + + /** + * Creates a snapshot of the cache. + * Should only be used for debugging or diagnoses. + */ + Collection snapshot(); } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java index 813dcd1d804..a204cd16db0 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java @@ -32,28 +32,42 @@ import java.util.Collections; import java.util.Deque; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; public class OOCLRUCacheScheduler implements OOCCacheScheduler { private static final boolean SANITY_CHECKS = false; private static final Log LOG = LogFactory.getLog(OOCLRUCacheScheduler.class.getName()); private final OOCIOHandler _ioHandler; - private final LinkedHashMap _cache; + private final HashMap _cache; private final HashMap _evictionCache; private final DeferredReadQueue _deferredReadRequests; private final Deque _processingReadRequests; private final HashMap _blockReads; - private long _hardLimit; + private volatile long _hardLimit; private long _evictionLimit; + private long _readBuffer; private final int _callerId; - private long _cacheSize; + private volatile long _cacheSize; private long _bytesUpForEviction; + private long _pinnedBytes; + private long _pinnedEvictingBytes; + private long _readingReservedBytes; + private long _warmPinnedBytes; private volatile boolean _running; private boolean _warnThrottling; + private long _lastEvictRun; + private volatile int _deferredReadCountHint; + private final AtomicBoolean _maintenanceRunning; + private final AtomicBoolean _maintenanceRequested; + private final AtomicBoolean _maintenanceNeedsIncr; - public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long hardLimit) { + public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long hardLimit, long readBuffer) { this._ioHandler = ioHandler; this._cache = new LinkedHashMap<>(1024, 0.75f, true); this._evictionCache = new HashMap<>(); @@ -62,10 +76,20 @@ public OOCLRUCacheScheduler(OOCIOHandler ioHandler, long evictionLimit, long har this._blockReads = new HashMap<>(); this._hardLimit = hardLimit; this._evictionLimit = evictionLimit; + this._readBuffer = readBuffer; this._cacheSize = 0; this._bytesUpForEviction = 0; + this._pinnedEvictingBytes = 0; + this._pinnedBytes = 0; + this._readingReservedBytes = 0; + this._warmPinnedBytes = 0; + this._lastEvictRun = System.currentTimeMillis(); + this._deferredReadCountHint = 0; this._running = true; this._warnThrottling = false; + this._maintenanceRunning = new AtomicBoolean(false); + this._maintenanceRequested = new AtomicBoolean(false); + this._maintenanceNeedsIncr = new AtomicBoolean(false); this._callerId = DMLScript.OOC_LOG_EVENTS ? OOCEventLog.registerCaller("LRUCacheScheduler") : 0; if (DMLScript.OOC_LOG_EVENTS) { @@ -92,7 +116,7 @@ public CompletableFuture request(BlockKey key) { synchronized(entry) { if (entry.getState().isAvailable()) { - if (entry.pin() == 0) + if (pinEntryWithAccounting(entry) == 0) throw new IllegalStateException(); couldPin = true; } @@ -175,7 +199,7 @@ public CompletableFuture> request(List keys, boolean if(allAvailable) { for(BlockEntry entry : entries) { synchronized(entry) { - if(entry.pin() == 0) + if(pinEntryWithAccounting(entry) == 0) throw new IllegalStateException(); } } @@ -227,11 +251,10 @@ private void scheduleDeferredRead(DeferredReadRequest deferredReadRequest) { synchronized(this) { double score = 0; int readyCount = 0; - for (BlockEntry entry : deferredReadRequest.getEntries()) { - synchronized(entry) { - if (entry.getState().isAvailable()) - readyCount++; - } + for(BlockEntry entry : deferredReadRequest.getEntries()) { + // Snapshot for scheduling heuristic only; exact state will be checked when reserving. + if(entry.getState().isAvailable()) + readyCount++; BlockReadState state = _blockReads.get(entry.getKey()); if (state != null) score += state.priority; @@ -241,10 +264,11 @@ private void scheduleDeferredRead(DeferredReadRequest deferredReadRequest) { if (!deferredReadRequest.getEntries().isEmpty()) score += ((double) readyCount) / deferredReadRequest.getEntries().size(); deferredReadRequest.setPriorityScore(score); - _deferredReadRequests.add(deferredReadRequest); + _deferredReadCountHint = _deferredReadRequests.size(); } - onCacheSizeChanged(false); // To schedule deferred reads if possible + onCacheSizeChanged(true); // Apply pressure from deferred read demand. + onCacheSizeChanged(false); // Attempt to schedule deferred reads. } @Override @@ -286,6 +310,11 @@ private BlockEntry put(BlockKey key, Object data, long size, boolean pin, OOCIOH if (avail != null || _evictionCache.containsKey(key)) throw new IllegalStateException("Cannot overwrite existing entries: " + key); _cacheSize += size; + if(pin) { + _pinnedBytes += size; + if(entry.getState() == BlockState.WARM) + _warmPinnedBytes += entry.getSize(); + } } onCacheSizeChanged(true); return entry; @@ -309,9 +338,12 @@ public void forget(BlockKey key) { shouldScheduleDeletion = entry.getState().isBackedByDisk() || entry.getState() == BlockState.EVICTING; cacheSizeDelta = transitionMemState(entry, BlockState.REMOVED); + if(entry.isPinned() && entry.getDataUnsafe() != null) + _pinnedBytes -= entry.getSize(); + if(_pinnedBytes < 0) + throw new IllegalStateException(); entry.setDataUnsafe(null); } - } } if (cacheSizeDelta != 0) @@ -322,60 +354,64 @@ public void forget(BlockKey key) { @Override public void pin(BlockEntry entry) { - if (!this._running) + if(!this._running) throw new IllegalStateException("Cache scheduler has been shut down."); + if(entry.fastPin()) + return; // Try to avoid using global lock first - int pinCount = entry.pin(); - if (pinCount == 0) - throw new IllegalStateException("Could not pin the requested entry: " + entry.getKey()); synchronized(this) { + synchronized(entry) { + int pinCount = pinEntryWithAccounting(entry); + if (pinCount == 0) + throw new IllegalStateException("Could not pin the requested entry: " + entry.getKey()); + } // Access element in cache for Lru - _cache.get(entry.getKey()); + //_cache.get(entry.getKey()); } } @Override public void unpin(BlockEntry entry) { - boolean couldFree = entry.unpin(); - - if (couldFree) { - long cacheSizeDelta = 0; - boolean shouldCheckEviction = false; - synchronized(this) { + if(entry.fastUnpin()) + return; // Try to avoid using global lock first + long cacheSizeDelta = 0; + boolean shouldCheckEviction = false; + synchronized(this) { + synchronized(entry) { + if(!unpinEntryWithAccounting(entry)) + return; if (_cacheSize <= _evictionLimit) return; // Nothing to do + if (entry.isPinned()) + return; // Pin state changed so we cannot evict - synchronized(entry) { - if (entry.isPinned()) - return; // Pin state changed so we cannot evict - - if (entry.getState().isAvailable() && entry.getState().isBackedByDisk()) { - if (entry.getRetainHintCount() > 0) { - shouldCheckEviction = true; - } - else { - cacheSizeDelta = transitionMemState(entry, BlockState.COLD); - long cleared = entry.clear(); - if (cleared != entry.getSize()) - throw new IllegalStateException(); - _cache.remove(entry.getKey()); - _evictionCache.put(entry.getKey(), entry); - } - } else if (entry.getState() == BlockState.HOT) { - if (entry.getRetainHintCount() > 0) { - shouldCheckEviction = true; - } - else { - cacheSizeDelta = onUnpinnedHotBlockUnderMemoryPressure(entry); - } + if (entry.getState().isAvailable() && entry.getState().isBackedByDisk()) { + if (entry.getRetainHintCount() > 0) { + shouldCheckEviction = true; + } + else { + cacheSizeDelta = transitionMemState(entry, BlockState.COLD); + long cleared = entry.clear(); + if (cleared != entry.getSize()) + throw new IllegalStateException(); + _cache.remove(entry.getKey()); + _evictionCache.put(entry.getKey(), entry); + } + } + else if (entry.getState() == BlockState.HOT) { + if (entry.getRetainHintCount() > 0) { + shouldCheckEviction = true; + } + else { + cacheSizeDelta = onUnpinnedHotBlockUnderMemoryPressure(entry); } } } - if (cacheSizeDelta != 0) - onCacheSizeChanged(cacheSizeDelta > 0); - else if (shouldCheckEviction) - onCacheSizeChanged(true); } + if (cacheSizeDelta != 0) + onCacheSizeChanged(cacheSizeDelta > 0); + else if (shouldCheckEviction) + onCacheSizeChanged(true); } @Override @@ -383,6 +419,16 @@ public synchronized long getCacheSize() { return _cacheSize; } + @Override + public synchronized long getPinnedBytes() { + return _pinnedBytes; + } + + @Override + public synchronized long getHardLimit() { + return _hardLimit; + } + @Override public boolean isWithinLimits() { return _cacheSize < _hardLimit; @@ -398,14 +444,25 @@ public synchronized void shutdown() { this._running = false; if(!_cache.isEmpty() || !_evictionCache.isEmpty()) { System.out.println("[WARN] Cache still holds " + _cache.size() + " / " + _evictionCache.size() + " blocks"); + + Set cachedStreams = _cache.keySet().stream().map(BlockKey::getStreamId).collect(Collectors.toSet()); + Set evictedStreams = _evictionCache.keySet().stream().map(BlockKey::getStreamId).collect(Collectors.toSet()); + cachedStreams.addAll(evictedStreams); + System.out.println("[WARN] Affected stream IDs: " + cachedStreams + ", Pinned: " + _cache.values().stream().mapToInt( + e -> e.isPinned() ? 1 : 0).sum()); } _cache.clear(); _evictionCache.clear(); _processingReadRequests.clear(); _deferredReadRequests.clear(); + _deferredReadCountHint = 0; _blockReads.clear(); _cacheSize = 0; _bytesUpForEviction = 0; + _pinnedBytes = 0; + _pinnedEvictingBytes = 0; + _readingReservedBytes = 0; + _warmPinnedBytes = 0; } @Override @@ -414,23 +471,60 @@ public synchronized void updateLimits(long evictionLimit, long hardLimit) { _hardLimit = hardLimit; } + @Override + public synchronized Collection snapshot() { + int l = _cache.size() + _evictionCache.size(); + ArrayList out = new ArrayList<>(l); + out.addAll(_cache.values()); + out.addAll(_evictionCache.values()); + return out; + } + /** * Must be called while this cache and the corresponding entry are locked */ private long onUnpinnedHotBlockUnderMemoryPressure(BlockEntry entry) { long cacheSizeDelta = transitionMemState(entry, BlockState.EVICTING); - evict(entry); + evict(entry, true); return cacheSizeDelta; } private void onCacheSizeChanged(boolean incr) { - if (incr) + if(incr) + _maintenanceNeedsIncr.set(true); + _maintenanceRequested.set(true); + if(!_maintenanceRunning.compareAndSet(false, true)) + return; + + runMaintenanceLoop(); + } + + private void runMaintenanceLoop() { + while(true) { + try { + do { + _maintenanceRequested.set(false); + onCacheSizeChangedInternal(_maintenanceNeedsIncr.getAndSet(false)); + } while(_maintenanceRequested.get()); + } + finally { + _maintenanceRunning.set(false); + } + + // Re-check in case a request came in after releasing the running flag. + if(!(_maintenanceRequested.get() && _maintenanceRunning.compareAndSet(false, true))) + return; + } + } + + private void onCacheSizeChangedInternal(boolean incr) { + if(incr) onCacheSizeIncremented(); - else { + else while(onCacheSizeDecremented()) {} - } - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction); + if(DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onCacheSizeChangedEvent(_callerId, System.nanoTime(), _cacheSize, _bytesUpForEviction, + _pinnedBytes, _readingReservedBytes); } private synchronized void sanityCheck() { @@ -454,15 +548,27 @@ else if (_warnThrottling && _cacheSize < _hardLimit) { int total = 0; long actualCacheSize = 0; long upForEviction = 0; + long actualPinnedBytes = 0; + long actualPinnedEvictingBytes = 0; + long actualWarmPinnedBytes = 0; + long actualReadingReservedBytes = 0; for (BlockEntry entry : _cache.values()) { - if (entry.isPinned()) + if (entry.isPinned()) { pinned++; + actualPinnedBytes += entry.getSize(); + if(entry.getState() == BlockState.WARM) + actualWarmPinnedBytes += entry.getSize(); + } if (entry.getState().isBackedByDisk()) backedByDisk++; if (entry.getState() == BlockState.EVICTING) { evicting++; upForEviction += entry.getSize(); + if(entry.isPinned()) + actualPinnedEvictingBytes += entry.getSize(); } + if(entry.getState() == BlockState.READING) + actualReadingReservedBytes += entry.getSize(); if (!entry.getState().isAvailable()) throw new IllegalStateException(); total++; @@ -471,13 +577,32 @@ else if (_warnThrottling && _cacheSize < _hardLimit) { for (BlockEntry entry : _evictionCache.values()) { if (entry.getState().isAvailable()) throw new IllegalStateException("Invalid eviction state: " + entry.getState()); + if (entry.getState() == BlockState.EVICTING && entry.isPinned()) + actualPinnedEvictingBytes += entry.getSize(); if (entry.getState() == BlockState.READING) actualCacheSize += entry.getSize(); + if (entry.getState() == BlockState.READING) + actualReadingReservedBytes += entry.getSize(); + if (entry.isPinned()) { + actualPinnedBytes += entry.getSize(); + if(entry.getState() == BlockState.WARM) + actualWarmPinnedBytes += entry.getSize(); + } } if (actualCacheSize != _cacheSize) throw new IllegalStateException(actualCacheSize + " != " + _cacheSize); if (upForEviction != _bytesUpForEviction) throw new IllegalStateException(upForEviction + " != " + _bytesUpForEviction); + if (actualPinnedBytes != _pinnedBytes) + throw new IllegalStateException(actualPinnedBytes + " != " + _pinnedBytes); + if (actualPinnedEvictingBytes != _pinnedEvictingBytes) + throw new IllegalStateException(actualPinnedEvictingBytes + " != " + _pinnedEvictingBytes); + if (_pinnedEvictingBytes > _bytesUpForEviction) + throw new IllegalStateException(_pinnedEvictingBytes + " > " + _bytesUpForEviction); + if(actualWarmPinnedBytes != _warmPinnedBytes) + throw new IllegalStateException(actualWarmPinnedBytes + " != " + _warmPinnedBytes); + if (actualReadingReservedBytes != _readingReservedBytes) + throw new IllegalStateException(actualReadingReservedBytes + " != " + _readingReservedBytes); System.out.println("=========="); System.out.println("Limit: " + _evictionLimit/1000 + "KB"); System.out.println("Memory: (" + _cacheSize/1000 + "KB - " + _bytesUpForEviction/1000 + "KB) / " + _hardLimit/1000 + "KB"); @@ -487,43 +612,57 @@ else if (_warnThrottling && _cacheSize < _hardLimit) { } private void onCacheSizeIncremented() { + if(System.currentTimeMillis() - _lastEvictRun < 5) + return; // Debounce (at least 5ms) long cacheSizeDelta = 0; - List upForEviction; + List upForEvictionNeedsWrite; + List upForEvictionNoWrite; synchronized(this) { - if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + long pressure = _cacheSize + _readBuffer - _bytesUpForEviction - _warmPinnedBytes; + if(pressure <= _evictionLimit) return; // Nothing to do + long overshoot = Math.max((long)(0.1 * _evictionLimit), 10000000); + long lowLimit = _evictionLimit - _readBuffer - overshoot; + + //System.out.println("[CACHE] Claiming " + (pressure + overshoot - _evictionLimit)/1000 + "kB (last claim was " + (System.currentTimeMillis() - _lastEvictRun) + "ms ago)"); + // Scan for values that can be evicted Collection entries = _cache.values(); List toRemove = new ArrayList<>(); - upForEviction = new ArrayList<>(); + upForEvictionNeedsWrite = new ArrayList<>(); + upForEvictionNoWrite = new ArrayList<>(); for(int pass = 0; pass < 2; pass++) { boolean allowRetainHint = pass == 1; for(BlockEntry entry : entries) { - if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + if(getEvictionPressure() <= lowLimit) break; synchronized(entry) { - if(entry.isPinned()) - continue; + //if(entry.isPinned()) + // continue; if(!allowRetainHint && entry.getRetainHintCount() > 0) continue; if(entry.getState() == BlockState.COLD || entry.getState() == BlockState.EVICTING) continue; - if(entry.getState().isBackedByDisk()) { + if(entry.getState().isBackedByDisk() && !entry.isPinned()) { cacheSizeDelta += transitionMemState(entry, BlockState.COLD); entry.clear(); toRemove.add(entry); } else { + boolean needsWrite = !entry.getState().isBackedByDisk(); cacheSizeDelta += transitionMemState(entry, BlockState.EVICTING); - upForEviction.add(entry); + if(needsWrite) + upForEvictionNeedsWrite.add(entry); + else + upForEvictionNoWrite.add(entry); } } } - if(_cacheSize - _bytesUpForEviction <= _evictionLimit) + if(getEvictionPressure() <= lowLimit) break; } @@ -533,23 +672,31 @@ private void onCacheSizeIncremented() { } sanityCheck(); + _lastEvictRun = System.currentTimeMillis(); } - for (BlockEntry entry : upForEviction) { - evict(entry); - } + for (BlockEntry entry : upForEvictionNeedsWrite) + evict(entry, true); + for (BlockEntry entry : upForEvictionNoWrite) + evict(entry, false); if (cacheSizeDelta != 0) onCacheSizeChanged(cacheSizeDelta > 0); } + private long getEvictionPressure() { + return _cacheSize + _readBuffer - _bytesUpForEviction; + } + private boolean onCacheSizeDecremented() { + if(_cacheSize + 10000000 >= _hardLimit || _deferredReadCountHint == 0) + return false; boolean allReserved = true; boolean reading = false; List> toRead; DeferredReadRequest req; synchronized(this) { - if(_cacheSize >= _hardLimit || _deferredReadRequests.isEmpty()) + if(_cacheSize + 10000000 >= _hardLimit || _deferredReadRequests.isEmpty()) return false; // Nothing to do // Try to schedule the next disk read @@ -563,7 +710,7 @@ private boolean onCacheSizeDecremented() { BlockEntry entry = req.getEntries().get(idx); synchronized(entry) { if(entry.getState().isAvailable()) { - if(entry.pin() == 0) + if(pinEntryWithAccounting(entry) == 0) throw new IllegalStateException(); req.setPinned(idx); } @@ -589,6 +736,7 @@ else if (entry.getState() == BlockState.READING) { if(allReserved) { _deferredReadRequests.poll(); + _deferredReadCountHint = _deferredReadRequests.size(); if (!toRead.isEmpty()) _processingReadRequests.add(req); } @@ -606,6 +754,7 @@ else if(allReserved && reading && req.isComplete()) { synchronized(this) { _processingReadRequests.remove(req); _deferredReadRequests.remove(req); + _deferredReadCountHint = _deferredReadRequests.size(); } req.getFuture().complete(req.getEntries()); return true; @@ -628,9 +777,10 @@ else if(allReserved && reading && req.isComplete()) { else { LOG.error("Uncaught CacheError", t); } + onCacheSizeChanged(false); return; } - java.util.Set completedRequests = new java.util.HashSet<>(); + Set completedRequests = new HashSet<>(); synchronized(this) { synchronized(r) { transitionMemState(r, BlockState.WARM); @@ -642,7 +792,7 @@ else if(allReserved && reading && req.isComplete()) { if(state != null) { for(DeferredReadWaiter waiter : state.waiters) { synchronized(r) { - if(r.pin() == 0) + if(pinEntryWithAccounting(r) == 0) throw new IllegalStateException(); if(waiter.request.setPinned(waiter.index) || waiter.request.isComplete()) completedRequests.add(waiter.request); @@ -655,18 +805,24 @@ else if(allReserved && reading && req.isComplete()) { _processingReadRequests.remove(done); _deferredReadRequests.remove(done); } + _deferredReadCountHint = _deferredReadRequests.size(); sanityCheck(); } for(DeferredReadRequest done : completedRequests) done.getFuture().complete(done.getEntries()); + onCacheSizeChanged(false); }); } return false; } - private void evict(final BlockEntry entry) { + private void evict(final BlockEntry entry, boolean needsWrite) { + if(!needsWrite) { + onEvicted(entry); + return; + } CompletableFuture future = _ioHandler.scheduleEviction(entry); future.whenComplete((r, e) -> onEvicted(entry)); } @@ -718,14 +874,19 @@ private long transitionMemState(BlockEntry entry, BlockState newState) { long sz = entry.getSize(); long oldCacheSize = _cacheSize; + boolean pinned = entry.isPinned(); // Remove old contribution switch (oldState) { case REMOVED: throw new IllegalStateException(); case HOT: + _cacheSize -= sz; + break; case WARM: _cacheSize -= sz; + if(pinned) + _warmPinnedBytes -= entry.getSize(); break; case EVICTING: _cacheSize -= sz; @@ -733,6 +894,7 @@ private long transitionMemState(BlockEntry entry, BlockState newState) { break; case READING: _cacheSize -= sz; + _readingReservedBytes -= sz; break; case COLD: break; @@ -744,8 +906,12 @@ private long transitionMemState(BlockEntry entry, BlockState newState) { case COLD: break; case HOT: + _cacheSize += sz; + break; case WARM: _cacheSize += sz; + if(pinned) + _warmPinnedBytes += entry.getSize(); break; case EVICTING: _cacheSize += sz; @@ -753,13 +919,67 @@ private long transitionMemState(BlockEntry entry, BlockState newState) { break; case READING: _cacheSize += sz; + _readingReservedBytes += sz; break; } + if(oldState == BlockState.EVICTING && entry.isPinned()) + _pinnedEvictingBytes -= sz; + if(newState == BlockState.EVICTING && entry.isPinned()) + _pinnedEvictingBytes += sz; + if(_pinnedEvictingBytes < 0) + throw new IllegalStateException(); + if(_pinnedEvictingBytes > _bytesUpForEviction) + throw new IllegalStateException(_pinnedEvictingBytes + " > " + _bytesUpForEviction); + entry.setState(newState); return _cacheSize - oldCacheSize; } + /** + * Requires scheduler lock. + */ + private int pinEntryWithAccounting(BlockEntry entry) { + int pinCount = entry.pin(); + if(pinCount == 1) { + _pinnedBytes += entry.getSize(); + switch(entry.getState()) { + case EVICTING: + _pinnedEvictingBytes += entry.getSize(); + break; + case WARM: + _warmPinnedBytes += entry.getSize(); + break; + } + } + return pinCount; + } + + /** + * Requires scheduler lock and entry lock. + * @return true if this call transitioned pin count to zero. + */ + private boolean unpinEntryWithAccounting(BlockEntry entry) { + boolean couldFree = entry.unpin(); + // Second check (entry.getDataUnsafe()...) is needed for potential forget(...) calls + if(couldFree && entry.getDataUnsafe() != null) { + _pinnedBytes -= entry.getSize(); + switch(entry.getState()) { + case EVICTING: + _pinnedEvictingBytes -= entry.getSize(); + break; + case WARM: + _warmPinnedBytes -= entry.getSize(); + break; + } + } + if(_pinnedBytes < 0) + throw new IllegalStateException(); + if(_pinnedEvictingBytes < 0) + throw new IllegalStateException(); + return couldFree; + } + private void registerWaiter(BlockKey key, DeferredReadRequest request, int index) { BlockReadState state = _blockReads.computeIfAbsent(key, k -> new BlockReadState()); state.waiters.add(new DeferredReadWaiter(request, index)); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java index aca99ed0966..3146439165f 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java @@ -666,6 +666,7 @@ private void evictTask(CloseableQueue } } catch(IOException | InterruptedException ex) { + ex.printStackTrace(); throw new DMLRuntimeException(ex); } catch(Exception ignored) { diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java index 2279272afa6..130d7a27503 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stats/OOCEventLog.java @@ -89,9 +89,10 @@ public static void onDiskReadEvent(int callerId, long startTimestamp, long endTi _data[idx] = size; } - public static void onCacheSizeChangedEvent(int callerId, long timestamp, long cacheSize, long bytesToEvict) { - int idx = _logCtr.getAndIncrement(); - if(idx >= _eventTypes.length) + public static void onCacheSizeChangedEvent(int callerId, long timestamp, long cacheSize, long bytesToEvict, + long pinnedBytes, long readingReservedBytes) { + int idx = _logCtr.getAndAdd(2); + if(idx + 1 >= _eventTypes.length) return; _eventTypes[idx] = EventType.CACHESIZE_CHANGE; _startTimestamps[idx] = timestamp; @@ -99,6 +100,14 @@ public static void onCacheSizeChangedEvent(int callerId, long timestamp, long ca _callerIds[idx] = callerId; _threadIds[idx] = Thread.currentThread().getId(); _data[idx] = cacheSize; + + int idxCont = idx + 1; + _eventTypes[idxCont] = EventType.CACHESIZE_CHANGE_CONT; + _startTimestamps[idxCont] = timestamp; + _endTimestamps[idxCont] = pinnedBytes; + _callerIds[idxCont] = callerId; + _threadIds[idxCont] = Thread.currentThread().getId(); + _data[idxCont] = readingReservedBytes; } public static void putRunSetting(String setting, Object data) { @@ -118,7 +127,36 @@ public static String getDiskWriteEventsCSV() { } public static String getCacheSizeEventsCSV() { - return getFilteredCSV("ThreadID,CallerID,Timestamp,ScheduledEvictionSize,CacheSize\n", EventType.CACHESIZE_CHANGE, true); + StringBuilder sb = new StringBuilder(); + sb.append("ThreadID,CallerID,Timestamp,ScheduledEvictionSize,CacheSize,PinnedSize,ReadReservedSize\n"); + + int maxIdx = Math.min(_logCtr.get(), _eventTypes.length); + for (int i = 0; i < maxIdx; i++) { + if (_eventTypes[i] != EventType.CACHESIZE_CHANGE) + continue; + long pinnedSize = 0; + long readReservedSize = 0; + if(i + 1 < maxIdx && _eventTypes[i + 1] == EventType.CACHESIZE_CHANGE_CONT) { + pinnedSize = _endTimestamps[i + 1]; + readReservedSize = _data[i + 1]; + } + sb.append(_threadIds[i]); + sb.append(','); + sb.append(_callerNames.get(_callerIds[i])); + sb.append(','); + sb.append(_startTimestamps[i]); + sb.append(','); + sb.append(_endTimestamps[i]); + sb.append(','); + sb.append(_data[i]); + sb.append(','); + sb.append(pinnedSize); + sb.append(','); + sb.append(readReservedSize); + sb.append('\n'); + } + + return sb.toString(); } private static String getFilteredCSV(String header, EventType filter, boolean data) { @@ -182,6 +220,7 @@ public enum EventType { COMPUTE, DISK_WRITE, DISK_READ, - CACHESIZE_CHANGE + CACHESIZE_CHANGE, + CACHESIZE_CHANGE_CONT } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java index 7d0a27932f1..907f78795a2 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java @@ -218,7 +218,7 @@ public void messageDownstream(OOCStreamMessage msg) { @Override public void setUpstreamMessageRelay(Consumer relay) { - throw new UnsupportedOperationException(); + _taskQueue.setUpstreamMessageRelay(relay); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java index d1453506a2b..b0ef37280f9 100644 --- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java @@ -90,6 +90,7 @@ import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; import org.apache.sysds.runtime.instructions.gpu.GPUInstruction; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.spark.SPInstruction; import org.apache.sysds.runtime.lineage.Lineage; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -484,7 +485,7 @@ public static Instruction cloneInstruction( Instruction oInst, long pid, boolean try { if( oInst instanceof CPInstruction || oInst instanceof SPInstruction || oInst instanceof FEDInstruction - || oInst instanceof GPUInstruction ) { + || oInst instanceof GPUInstruction || oInst instanceof OOCInstruction ) { if( oInst instanceof FunctionCallCPInstruction && cpFunctions ) { FunctionCallCPInstruction tmp = (FunctionCallCPInstruction) oInst; if( !plain ) { diff --git a/src/test/java/org/apache/sysds/test/component/ooc/cache/BlockEntryTestAccess.java b/src/test/java/org/apache/sysds/test/component/ooc/cache/BlockEntryTestAccess.java new file mode 100644 index 00000000000..9b9d95064fd --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/ooc/cache/BlockEntryTestAccess.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.ooc.cache; + +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.BlockState; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; + +final class BlockEntryTestAccess { + private static final Constructor CTOR; + private static final Method GET_DATA_UNSAFE; + private static final Method SET_DATA_UNSAFE; + private static final Method SET_STATE; + + static { + try { + CTOR = BlockEntry.class.getDeclaredConstructor(BlockKey.class, long.class, Object.class); + CTOR.setAccessible(true); + + GET_DATA_UNSAFE = BlockEntry.class.getDeclaredMethod("getDataUnsafe"); + GET_DATA_UNSAFE.setAccessible(true); + + SET_DATA_UNSAFE = BlockEntry.class.getDeclaredMethod("setDataUnsafe", Object.class); + SET_DATA_UNSAFE.setAccessible(true); + + SET_STATE = BlockEntry.class.getDeclaredMethod("setState", BlockState.class); + SET_STATE.setAccessible(true); + } + catch(ReflectiveOperationException e) { + throw new ExceptionInInitializerError(e); + } + } + + private BlockEntryTestAccess() { + // Utility class + } + + static BlockEntry newBlockEntry(BlockKey key, long size, Object data) { + try { + return CTOR.newInstance(key, size, data); + } + catch(ReflectiveOperationException e) { + throw new RuntimeException("Failed to create BlockEntry via reflection", e); + } + } + + static Object getDataUnsafe(BlockEntry entry) { + try { + return GET_DATA_UNSAFE.invoke(entry); + } + catch(ReflectiveOperationException e) { + throw new RuntimeException("Failed to call BlockEntry#getDataUnsafe via reflection", e); + } + } + + static void setDataUnsafe(BlockEntry entry, Object data) { + try { + SET_DATA_UNSAFE.invoke(entry, data); + } + catch(ReflectiveOperationException e) { + throw new RuntimeException("Failed to call BlockEntry#setDataUnsafe via reflection", e); + } + } + + static void setState(BlockEntry entry, BlockState state) { + try { + SET_STATE.invoke(entry, state); + } + catch(ReflectiveOperationException e) { + throw new RuntimeException("Failed to call BlockEntry#setState via reflection", e); + } + } +} diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java b/src/test/java/org/apache/sysds/test/component/ooc/cache/OOCLRUCacheSchedulerTest.java similarity index 66% rename from src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java rename to src/test/java/org/apache/sysds/test/component/ooc/cache/OOCLRUCacheSchedulerTest.java index 66a46c03269..bdd75c1411a 100644 --- a/src/test/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheSchedulerTest.java +++ b/src/test/java/org/apache/sysds/test/component/ooc/cache/OOCLRUCacheSchedulerTest.java @@ -17,10 +17,15 @@ * under the License. */ -package org.apache.sysds.runtime.ooc.cache; +package org.apache.sysds.test.component.ooc.cache; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.BlockState; +import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; +import org.apache.sysds.runtime.ooc.cache.OOCLRUCacheScheduler; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -46,7 +51,7 @@ public class OOCLRUCacheSchedulerTest { @Before public void setUp() { _handler = new FakeIOHandler(); - _scheduler = new OOCLRUCacheScheduler(_handler, 0, Long.MAX_VALUE); + _scheduler = new OOCLRUCacheScheduler(_handler, 0, Long.MAX_VALUE, 40000000); } @After @@ -60,7 +65,7 @@ public void tearDown() { @Test public void testImmediateRequestPinsBlock() throws Exception { FakeIOHandler handler = new FakeIOHandler(); - OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(handler, Long.MAX_VALUE, Long.MAX_VALUE); + OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(handler, Long.MAX_VALUE, Long.MAX_VALUE, 40000000); try { BlockKey key = new BlockKey(1, 1); scheduler.put(key, new Object(), ENTRY_SIZE); @@ -70,6 +75,7 @@ public void testImmediateRequestPinsBlock() throws Exception { Assert.assertTrue(fetched.isPinned()); scheduler.unpin(fetched); Assert.assertEquals(ENTRY_SIZE, scheduler.getCacheSize()); + scheduler.forget(key); } finally { scheduler.shutdown(); @@ -95,6 +101,52 @@ public void testDeferredReadSingleBlock() throws Exception { Assert.assertEquals(ENTRY_SIZE, _scheduler.getCacheSize()); _scheduler.unpin(fetched); Assert.assertEquals(0, _scheduler.getCacheSize()); + _scheduler.forget(key); + } + + @Test + public void testDeferredReadDemandAppliesEvictionPressure() throws Exception { + FakeIOHandler handler = new FakeIOHandler(); + OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(handler, 0, Long.MAX_VALUE, 40000000); + BlockKey coldKey = new BlockKey(9, 1); + BlockKey warmKey = new BlockKey(9, 2); + try { + BlockEntry cold = putColdSourceBacked(scheduler, coldKey); + Assert.assertEquals(BlockState.COLD, cold.getState()); + + OOCIOHandler.SourceBlockDescriptor desc = new OOCIOHandler.SourceBlockDescriptor( + "unused", Types.FileFormat.BINARY, new MatrixIndexes(1, 1), 0, 0, ENTRY_SIZE); + BlockEntry warm = scheduler.putAndPinSourceBacked(warmKey, new Object(), ENTRY_SIZE, desc); + Assert.assertEquals(BlockState.WARM, warm.getState()); + + // Keep WARM in-memory at the hard limit first, then tighten soft limit afterwards. + scheduler.updateLimits(ENTRY_SIZE, ENTRY_SIZE); + scheduler.unpin(warm); + Assert.assertEquals(BlockState.WARM, warm.getState()); + Assert.assertEquals(ENTRY_SIZE, scheduler.getCacheSize()); + + // Deferred reads are only scheduled when hard limit is above the 10MB guard. + scheduler.updateLimits(0, 20000000L); + + CompletableFuture future = scheduler.request(coldKey); + Assert.assertFalse(future.isDone()); + waitForReadCount(handler, coldKey, 1); + // Under current scheduler policy this may remain WARM if no immediate demotion is required. + Assert.assertFalse(warm.isPinned()); + + handler.completeRead(coldKey); + BlockEntry fetched = future.get(WAIT_TIMEOUT_SEC, TimeUnit.SECONDS); + Assert.assertTrue(fetched.isPinned()); + scheduler.unpin(fetched); + safeForget(scheduler, coldKey); + safeForget(scheduler, warmKey); + } + finally { + safeForget(scheduler, coldKey); + safeForget(scheduler, warmKey); + scheduler.shutdown(); + handler.shutdown(); + } } @Test @@ -131,11 +183,14 @@ public void testMergeOverlappingRequests() throws Exception { resA.forEach(_scheduler::unpin); resB.forEach(_scheduler::unpin); Assert.assertEquals(0, _scheduler.getCacheSize()); + _scheduler.forget(key1); + _scheduler.forget(key2); + _scheduler.forget(key3); } @Test public void testPrioritizeReordersDeferredRequests() throws Exception { - OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(_handler, 0, 0); + OOCLRUCacheScheduler scheduler = new OOCLRUCacheScheduler(_handler, 0, 0, 0); try { BlockKey key1 = new BlockKey(1, 1); BlockKey key2 = new BlockKey(1, 2); @@ -154,13 +209,35 @@ public void testPrioritizeReordersDeferredRequests() throws Exception { scheduler.prioritize(key3, 1); List after = snapshotDeferredOrder(scheduler); - Assert.assertEquals(List.of(key1, key3, key2), after); + Assert.assertEquals(List.of(key3, key1, key2), after); + scheduler.forget(key1); + scheduler.forget(key2); + scheduler.forget(key3); } finally { scheduler.shutdown(); } } + private static void waitForReadCount(FakeIOHandler handler, BlockKey key, int expected) throws InterruptedException { + long timeoutNanos = TimeUnit.SECONDS.toNanos(WAIT_TIMEOUT_SEC); + long start = System.nanoTime(); + while (System.nanoTime() - start < timeoutNanos) { + if (handler.getReadCount(key) == expected) + return; + Thread.sleep(1); + } + Assert.assertEquals(expected, handler.getReadCount(key)); + } + + private static void safeForget(OOCLRUCacheScheduler scheduler, BlockKey key) { + try { + scheduler.forget(key); + } + catch (RuntimeException ignored) { + } + } + private BlockEntry putColdSourceBacked(BlockKey key) { return putColdSourceBacked(_scheduler, key); } @@ -178,9 +255,44 @@ private BlockEntry putColdSourceBacked(OOCLRUCacheScheduler scheduler, BlockKey private static List snapshotDeferredOrder(OOCLRUCacheScheduler scheduler) throws Exception { Field field = OOCLRUCacheScheduler.class.getDeclaredField("_deferredReadRequests"); field.setAccessible(true); - Deque deque = (Deque) field.get(scheduler); + Object queue = field.get(scheduler); + List requests = new ArrayList<>(); + + if (queue instanceof Deque) { + requests.addAll((Deque) queue); + } + else { + Field heapField = queue.getClass().getDeclaredField("_heap"); + heapField.setAccessible(true); + requests.addAll((List) heapField.get(queue)); + + if (!requests.isEmpty()) { + Class reqClass = requests.get(0).getClass(); + Field priorityField = reqClass.getDeclaredField("_priorityScore"); + Field sequenceField = reqClass.getDeclaredField("_sequence"); + priorityField.setAccessible(true); + sequenceField.setAccessible(true); + + requests.sort((a, b) -> { + try { + double pa = priorityField.getDouble(a); + double pb = priorityField.getDouble(b); + int byPriority = Double.compare(pb, pa); + if (byPriority != 0) + return byPriority; + long sa = sequenceField.getLong(a); + long sb = sequenceField.getLong(b); + return Long.compare(sa, sb); + } + catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + } + List order = new ArrayList<>(); - for (Object obj : deque) { + for (Object obj : requests) { Field entriesField = obj.getClass().getDeclaredField("_entries"); entriesField.setAccessible(true); List entries = (List) entriesField.get(obj); @@ -249,8 +361,8 @@ public void completeRead(BlockKey key) { BlockEntry entry = _readEntries.get(key); if (entry == null) throw new IllegalStateException("No registered entry for " + key); - entry.setDataUnsafe(new Object()); + BlockEntryTestAccess.setDataUnsafe(entry, new Object()); future.complete(entry); } + } } -} diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java b/src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedCacheSchedulerTest.java similarity index 87% rename from src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java rename to src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedCacheSchedulerTest.java index 423c2b7f425..83c2fd59669 100644 --- a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedCacheSchedulerTest.java +++ b/src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedCacheSchedulerTest.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.sysds.runtime.ooc.cache; +package org.apache.sysds.test.component.ooc.cache; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; @@ -25,10 +25,17 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.io.MatrixWriter; import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.BlockState; +import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; +import org.apache.sysds.runtime.ooc.cache.OOCLRUCacheScheduler; +import org.apache.sysds.runtime.ooc.cache.OOCMatrixIOHandler; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -46,7 +53,7 @@ public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); handler = new OOCMatrixIOHandler(); - scheduler = new OOCLRUCacheScheduler(handler, 0, Long.MAX_VALUE); + scheduler = new OOCLRUCacheScheduler(handler, 0, Long.MAX_VALUE, 40000000); } @After @@ -79,11 +86,11 @@ public void testPutSourceBackedAndReload() throws Exception { BlockKey key = new BlockKey(11, 0); BlockEntry entry = scheduler.putAndPinSourceBacked(key, imv, ((MatrixBlock) imv.getValue()).getExactSerializedSize(), desc); - org.junit.Assert.assertEquals(BlockState.WARM, entry.getState()); + Assert.assertEquals(BlockState.WARM, entry.getState()); scheduler.unpin(entry); - org.junit.Assert.assertEquals(BlockState.COLD, entry.getState()); - org.junit.Assert.assertNull(entry.getDataUnsafe()); + Assert.assertEquals(BlockState.COLD, entry.getState()); + Assert.assertNull(BlockEntryTestAccess.getDataUnsafe(entry)); BlockEntry reloaded = scheduler.request(key).get(); IndexedMatrixValue reloadImv = (IndexedMatrixValue) reloaded.getData(); diff --git a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java b/src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedReadOOCIOHandlerTest.java similarity index 85% rename from src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java rename to src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedReadOOCIOHandlerTest.java index e688bf0f1c0..7c93af0ba09 100644 --- a/src/test/java/org/apache/sysds/runtime/ooc/cache/SourceBackedReadOOCIOHandlerTest.java +++ b/src/test/java/org/apache/sysds/test/component/ooc/cache/SourceBackedReadOOCIOHandlerTest.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.sysds.runtime.ooc.cache; +package org.apache.sysds.test.component.ooc.cache; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; @@ -25,10 +25,16 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.io.MatrixWriter; import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.ooc.cache.BlockEntry; +import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.BlockState; +import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; +import org.apache.sysds.runtime.ooc.cache.OOCMatrixIOHandler; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -69,17 +75,17 @@ public void testSourceBackedScheduleRead() throws Exception { rows, cols, blen, src.getNonZeros(), Long.MAX_VALUE, true, target); OOCIOHandler.SourceReadResult res = handler.scheduleSourceRead(req).get(); - org.junit.Assert.assertFalse(res.blocks.isEmpty()); + Assert.assertFalse(res.blocks.isEmpty()); OOCIOHandler.SourceBlockDescriptor desc = res.blocks.get(0); BlockKey key = new BlockKey(7, 0); handler.registerSourceLocation(key, desc); - BlockEntry entry = new BlockEntry(key, desc.serializedSize, null); - entry.setState(BlockState.COLD); + BlockEntry entry = BlockEntryTestAccess.newBlockEntry(key, desc.serializedSize, null); + BlockEntryTestAccess.setState(entry, BlockState.COLD); handler.scheduleRead(entry).get(); - IndexedMatrixValue imv = (IndexedMatrixValue) entry.getDataUnsafe(); + IndexedMatrixValue imv = (IndexedMatrixValue) BlockEntryTestAccess.getDataUnsafe(entry); MatrixBlock readBlock = (MatrixBlock) imv.getValue(); MatrixBlock expected = expectedBlock(src, desc.indexes, blen); TestUtils.compareMatrices(expected, readBlock, 1e-12); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/KMeansTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/KMeansTest.java new file mode 100644 index 00000000000..4a70d7b25af --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/KMeansTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import java.io.IOException; +import java.util.Random; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class KMeansTest extends AutomatedTestBase { + private static final String TEST_NAME = "KMeans"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + KMeansTest.class.getSimpleName() + "/"; + + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME_OOC = "C"; + private static final String OUTPUT_NAME_CP = "C_target"; + + private static final int ROWS = 10000; + private static final int COLS = 400; + private static final int K = 8; + private static final int RUNS = 3; + private static final int MAX_ITER = 50; + private static final int SEED = 7; + private static final int BLOCK_SIZE = 1000; + private static final double MAX_VAL = 2; + private static final double SPARSITY_DENSE = 1.0; + private static final double SPARSITY_SPARSE = 0.2; + private static final double EPS = 1e-9; + private static final double CLUSTER_NOISE = 0.05; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testKMeansDense() { + runKMeansTest(false); + } + + @Test + public void testKMeansSparse() { + runKMeansTest(true); + } + + private void runKMeansTest(boolean sparse) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + + String home = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = home + TEST_NAME + ".dml"; + + double[][] xData = generateClusteredInput(sparse); + writeBinaryWithMTD(INPUT_NAME, DataConverter.convertToMatrixBlock(xData)); + + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", + input(INPUT_NAME), Integer.toString(K), Integer.toString(RUNS), Integer.toString(MAX_ITER), + Double.toString(EPS), Integer.toString(SEED), output(OUTPUT_NAME_OOC)}; + runTest(true, false, null, -1); + + programArgs = new String[] {"-explain", "-stats", "-args", + input(INPUT_NAME), Integer.toString(K), Integer.toString(RUNS), Integer.toString(MAX_ITER), + Double.toString(EPS), Integer.toString(SEED), output(OUTPUT_NAME_CP)}; + runTest(true, false, null, -1); + + MatrixBlock centersOOC = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_OOC), + Types.FileFormat.BINARY, K, COLS, BLOCK_SIZE); + MatrixBlock centersCP = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_CP), + Types.FileFormat.BINARY, K, COLS, BLOCK_SIZE); + + TestUtils.compareMatrices(centersOOC, centersCP, EPS); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } + + private static double[][] generateClusteredInput(boolean sparse) { + Random rand = new Random(SEED); + double[][] centers = new double[K][COLS]; + for(int k = 0; k < K; k++) { + for(int c = 0; c < COLS; c++) + centers[k][c] = rand.nextDouble() * MAX_VAL; + } + + double[][] data = new double[ROWS][COLS]; + double keepProb = sparse ? SPARSITY_SPARSE : SPARSITY_DENSE; + for(int r = 0; r < ROWS; r++) { + int cluster = r % K; + for(int c = 0; c < COLS; c++) { + double v = centers[cluster][c] + rand.nextGaussian() * CLUSTER_NOISE; + v = Math.max(0, Math.min(MAX_VAL, v)); + if(rand.nextDouble() > keepProb) + v = 0; + data[r][c] = v; + } + } + return data; + } +} diff --git a/src/test/scripts/functions/ooc/KMeans.dml b/src/test/scripts/functions/ooc/KMeans.dml new file mode 100644 index 00000000000..ca1ed96a0e4 --- /dev/null +++ b/src/test/scripts/functions/ooc/KMeans.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +k = as.integer($2); +runs = as.integer($3); +max_iter = as.integer($4); +seed = as.integer($6); + +[C, Y] = kmeans(X = X, k = k, runs = runs, max_iter = max_iter, eps = $5, seed = seed); +write(C, $7, format = "binary"); +print(sum(Y));