Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public final class CometShuffleExternalSorterAsync
private final LinkedList<SpillInfo> spills = new LinkedList<>();

/** Peak memory used by this sorter so far, in bytes. */
private long peakMemoryUsedBytes;
private volatile long peakMemoryUsedBytes;

// Checksum calculator for each partition. Empty when shuffle checksum disabled.
private final long[] partitionChecksums;
Expand Down Expand Up @@ -152,8 +152,16 @@ public CometShuffleExternalSorterAsync(
this.tracingEnabled = (boolean) CometConf$.MODULE$.COMET_TRACING_ENABLED().get();

this.threadNum = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_ASYNC_THREAD_NUM().get();
assert (this.threadNum > 0);
if (this.threadNum <= 0) {
throw new IllegalArgumentException(
"spark.comet.columnar.shuffle.async.thread.num must be positive, got: " + this.threadNum);
}
this.threadPool = ShuffleThreadPool.getThreadPool();
if (this.threadPool == null) {
throw new IllegalStateException(
"Async shuffle thread pool is not initialized. "
+ "Ensure spark.comet.columnar.shuffle.async.enabled is true.");
}

this.preferDictionaryRatio =
(double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get();
Expand Down Expand Up @@ -215,10 +223,21 @@ public void spill() throws IOException {
SpillSorter spillingSorter = activeSpillSorter;
Callable<Void> task =
() -> {
spillingSorter.writeSortedFileNative(false, tracingEnabled);
final long spillSize = spillingSorter.freeMemory();
spillingSorter.freeArray();
spillingSorters.remove(spillingSorter);
long spillSize = 0;
try {
spillingSorter.writeSortedFileNative(false, tracingEnabled);
spillSize = spillingSorter.freeMemory();
} finally {
// Ensure cleanup happens even if writeSortedFileNative() throws.
// freeMemory() may have already been called above, but it's safe to call again
// (returns 0 if already freed). freeArray() must be called to release the pointer
// array.
if (spillSize == 0) {
spillSize = spillingSorter.freeMemory();
}
spillingSorter.freeArray();
spillingSorters.remove(spillingSorter);
}

// Reset the in-memory sorter's pointer array only after freeing up the memory pages
// holding the records. Otherwise, if the task is over allocated memory, then without
Expand All @@ -233,11 +252,20 @@ public void spill() throws IOException {
spillingSorters.add(spillingSorter);
asyncSpillTasks.add(threadPool.submit(task));

while (asyncSpillTasks.size() == threadNum) {
for (Future<Void> spillingTask : asyncSpillTasks) {
if (spillingTask.isDone()) {
asyncSpillTasks.remove(spillingTask);
break;
// If we've reached the max concurrent spill tasks, block until one completes.
// This provides backpressure to avoid unbounded memory growth.
while (asyncSpillTasks.size() >= threadNum) {
Future<Void> oldestTask = asyncSpillTasks.peek();
if (oldestTask != null) {
try {
oldestTask.get(); // Block until the oldest task completes
asyncSpillTasks.poll(); // Remove the completed task
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException("Interrupted while waiting for spill task", e);
} catch (ExecutionException e) {
asyncSpillTasks.poll(); // Remove the failed task
throw new IOException("Async spill task failed", e.getCause());
}
}
}
Expand Down Expand Up @@ -288,6 +316,23 @@ private long freeMemory() {
/** Force all memory and spill files to be deleted; called by shuffle error-handling code. */
@Override
public void cleanupResources() {
// Cancel any pending async spill tasks to stop background work.
// The tasks have try-finally blocks that will clean up their SpillSorter resources.
for (Future<Void> task : asyncSpillTasks) {
task.cancel(true);
}

// Wait briefly for cancelled tasks to complete their cleanup.
// This ensures SpillSorters are removed from spillingSorters before we iterate it.
for (Future<Void> task : asyncSpillTasks) {
try {
task.get(100, TimeUnit.MILLISECONDS);
} catch (Exception e) {
// Ignore - task was cancelled or failed, we're cleaning up anyway
}
}
asyncSpillTasks.clear();

freeMemory();

for (SpillInfo spill : spills) {
Expand Down Expand Up @@ -383,23 +428,38 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
final TempShuffleBlockId blockId = spilledFileInfo._1();
final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

// Waits for all async tasks to finish.
// Waits for all async tasks to finish, collecting any exceptions.
// We wait for all tasks even if some fail to ensure proper cleanup.
IOException firstException = null;
for (Future<Void> task : asyncSpillTasks) {
try {
task.get();
} catch (Exception e) {
throw new IOException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
if (firstException == null) {
firstException = new IOException("Interrupted while waiting for spill tasks", e);
}
} catch (ExecutionException e) {
if (firstException == null) {
firstException = new IOException("Async spill task failed", e.getCause());
} else {
firstException.addSuppressed(e.getCause());
}
}
}

asyncSpillTasks.clear();

if (firstException != null) {
throw firstException;
}

activeSpillSorter.setSpillInfo(spillInfo);
activeSpillSorter.writeSortedFileNative(true, tracingEnabled);

freeMemory();
}

return spills.toArray(new SpillInfo[spills.size()]);
return spills.toArray(new SpillInfo[0]);
}
}
Loading