Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ dev/release/comet-rm/workdir
spark/benchmarks
.DS_Store
comet-event-trace.json
__pycache__
111 changes: 96 additions & 15 deletions benchmarks/pyspark/README.md
@@ -17,9 +17,16 @@ specific language governing permissions and limitations
under the License.
-->

# Shuffle Size Comparison Benchmark
# PySpark Benchmarks

Compares shuffle file sizes between Spark, Comet JVM, and Comet Native shuffle implementations.
A suite of PySpark benchmarks for comparing performance between Spark, Comet JVM, and Comet Native implementations.

## Available Benchmarks

Run `python run_benchmark.py --list-benchmarks` to see all available benchmarks:

- **shuffle-hash** - Shuffle all columns using hash partitioning on group_key
- **shuffle-roundrobin** - Shuffle all columns using round-robin partitioning
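
The difference between the two strategies can be sketched in plain Python (a conceptual illustration only; Spark's hash partitioning actually uses Murmur3, not Python's built-in `hash`, and round-robin distribution is handled inside the shuffle writer):

```python
from itertools import count

def hash_partition(key, num_partitions):
    # Rows with the same key always land in the same partition.
    return hash(key) % num_partitions

_rr = count()

def round_robin_partition(num_partitions):
    # Rows are dealt out evenly regardless of key, which balances
    # partition sizes but scatters equal keys across partitions.
    return next(_rr) % num_partitions

rows = ["a", "a", "b", "c", "a"]
hash_parts = [hash_partition(k, 4) for k in rows]
rr_parts = [round_robin_partition(4) for _ in rows]

assert hash_parts[0] == hash_parts[1] == hash_parts[4]  # same key, same partition
assert rr_parts == [0, 1, 2, 3, 0]  # even spread
```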

## Prerequisites

@@ -56,42 +63,116 @@ spark-submit \
| `--rows`, `-r` | 10000000 | Number of rows |
| `--partitions`, `-p` | 200 | Number of output partitions |
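
Putting the options together, a data-generation invocation might look like the following (a sketch only: `generate_data.py` is a hypothetical name for the generator script, which the collapsed hunk above elides, and the master URL matches the examples below):

```shell
# Hypothetical invocation; substitute the actual generator script name.
spark-submit --master spark://master:7077 \
  generate_data.py \
  --output /tmp/shuffle-benchmark-data \
  --rows 10000000 \
  --partitions 200
```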

## Step 2: Run Benchmark
## Step 2: Run Benchmarks

Run benchmarks and check Spark UI for shuffle sizes:
### List Available Benchmarks

```bash
SPARK_MASTER=spark://master:7077 \
EXECUTOR_MEMORY=16g \
./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
python run_benchmark.py --list-benchmarks
```

Or run individual modes:
### Run Individual Benchmarks

You can run specific benchmarks by name:

```bash
# Spark baseline
# Hash partitioning shuffle - Spark baseline
spark-submit --master spark://master:7077 \
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-hash

# Comet JVM shuffle
# Round-robin shuffle - Spark baseline
spark-submit --master spark://master:7077 \
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-roundrobin

# Hash partitioning - Comet JVM shuffle
spark-submit --master spark://master:7077 \
--jars /path/to/comet.jar \
--conf spark.comet.enabled=true \
--conf spark.comet.exec.shuffle.enabled=true \
--conf spark.comet.exec.shuffle.mode=jvm \
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm --benchmark shuffle-hash

# Comet Native shuffle
# Round-robin - Comet Native shuffle
spark-submit --master spark://master:7077 \
--jars /path/to/comet.jar \
--conf spark.comet.enabled=true \
--conf spark.comet.exec.shuffle.enabled=true \
--conf spark.comet.shuffle.mode=native \
--conf spark.comet.exec.shuffle.mode=native \
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native
run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native --benchmark shuffle-roundrobin
```

### Run All Benchmarks

Use the provided script to run all benchmarks across all modes:

```bash
SPARK_MASTER=spark://master:7077 \
EXECUTOR_MEMORY=16g \
./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
```

## Checking Results

Open the Spark UI (default: http://localhost:4040) during each benchmark run to compare shuffle write sizes in the Stages tab.
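
Shuffle write sizes can also be collected programmatically from Spark's monitoring REST API instead of being read off the UI (a sketch; it assumes the default UI endpoint and that the application is still running, and `fetch_stages` is a hypothetical helper, not part of this benchmark suite):

```python
import json
from urllib.request import urlopen

def total_shuffle_write_bytes(stages):
    # Sum shuffle write bytes across stages, as reported by Spark's
    # /api/v1 monitoring endpoint (StageData.shuffleWriteBytes).
    return sum(s.get("shuffleWriteBytes", 0) for s in stages)

def fetch_stages(ui_url="http://localhost:4040"):
    # Look up the running application, then list its stages.
    apps = json.load(urlopen(f"{ui_url}/api/v1/applications"))
    app_id = apps[0]["id"]
    return json.load(urlopen(f"{ui_url}/api/v1/applications/{app_id}/stages"))

# Example with canned StageData-shaped records:
sample = [{"shuffleWriteBytes": 1024}, {"shuffleWriteBytes": 2048}, {}]
assert total_shuffle_write_bytes(sample) == 3072
```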

## Adding New Benchmarks

The benchmark framework makes it easy to add new benchmarks:

1. **Create a benchmark class** in the `benchmarks/` directory (or add it to an existing file):

```python
from typing import Any, Dict

from benchmarks.base import Benchmark


class MyBenchmark(Benchmark):
    @classmethod
    def name(cls) -> str:
        return "my-benchmark"

    @classmethod
    def description(cls) -> str:
        return "Description of what this benchmark does"

    def run(self) -> Dict[str, Any]:
        # Read data
        df = self.spark.read.parquet(self.data_path)

        # Run your benchmark operation
        def benchmark_operation():
            result = df.filter(...).groupBy(...).agg(...)
            result.write.mode("overwrite").parquet("/tmp/output")

        # Time it
        duration_ms = self._time_operation(benchmark_operation)

        return {
            'duration_ms': duration_ms,
            # Add any other metrics you want to track
        }
```

2. **Register the benchmark** in `benchmarks/__init__.py`:

```python
from .my_module import MyBenchmark

_BENCHMARK_REGISTRY = {
    # ... existing benchmarks
    MyBenchmark.name(): MyBenchmark,
}
```

3. **Run your new benchmark**:

```bash
python run_benchmark.py --data /path/to/data --mode spark --benchmark my-benchmark
```

The base `Benchmark` class provides:

- Automatic timing via `_time_operation()`
- Standard output formatting via `execute_timed()`
- Access to SparkSession, data path, and mode
- Spark configuration printing
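
The registry-plus-base-class pattern can be exercised without a cluster; the toy version below mirrors the structure with simplified stand-ins (these are not the actual Comet classes):

```python
from abc import ABC, abstractmethod

class ToyBenchmark(ABC):
    @classmethod
    @abstractmethod
    def name(cls) -> str: ...

    @abstractmethod
    def run(self) -> dict: ...

class NoopBenchmark(ToyBenchmark):
    @classmethod
    def name(cls) -> str:
        return "noop"

    def run(self) -> dict:
        return {"duration_ms": 0}

REGISTRY = {NoopBenchmark.name(): NoopBenchmark}

def get_benchmark(name):
    # Mirrors benchmarks/__init__.py: unknown names fail fast with
    # the list of valid choices.
    if name not in REGISTRY:
        raise KeyError(f"Unknown benchmark: {name}. Available: {sorted(REGISTRY)}")
    return REGISTRY[name]

assert get_benchmark("noop")().run() == {"duration_ms": 0}
```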
79 changes: 79 additions & 0 deletions benchmarks/pyspark/benchmarks/__init__.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Benchmark registry for PySpark benchmarks.

This module provides a central registry for discovering and running benchmarks.
"""

from typing import Dict, Type, List

from .base import Benchmark
from .shuffle import ShuffleHashBenchmark, ShuffleRoundRobinBenchmark


# Registry of all available benchmarks
_BENCHMARK_REGISTRY: Dict[str, Type[Benchmark]] = {
    ShuffleHashBenchmark.name(): ShuffleHashBenchmark,
    ShuffleRoundRobinBenchmark.name(): ShuffleRoundRobinBenchmark,
}


def get_benchmark(name: str) -> Type[Benchmark]:
"""
Get a benchmark class by name.

Args:
name: Benchmark name

Returns:
Benchmark class

Raises:
KeyError: If benchmark name is not found
"""
if name not in _BENCHMARK_REGISTRY:
available = ", ".join(sorted(_BENCHMARK_REGISTRY.keys()))
raise KeyError(
f"Unknown benchmark: {name}. Available benchmarks: {available}"
)
return _BENCHMARK_REGISTRY[name]


def list_benchmarks() -> List[tuple[str, str]]:
"""
List all available benchmarks.

Returns:
List of (name, description) tuples
"""
benchmarks = []
for name in sorted(_BENCHMARK_REGISTRY.keys()):
benchmark_cls = _BENCHMARK_REGISTRY[name]
benchmarks.append((name, benchmark_cls.description()))
return benchmarks


__all__ = [
    'Benchmark',
    'get_benchmark',
    'list_benchmarks',
    'ShuffleHashBenchmark',
    'ShuffleRoundRobinBenchmark',
]
127 changes: 127 additions & 0 deletions benchmarks/pyspark/benchmarks/base.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Base benchmark class providing common functionality for all benchmarks.
"""

import time
from abc import ABC, abstractmethod
from typing import Dict, Any

from pyspark.sql import SparkSession


class Benchmark(ABC):
"""Base class for all PySpark benchmarks."""

def __init__(self, spark: SparkSession, data_path: str, mode: str):
"""
Initialize benchmark.

Args:
spark: SparkSession instance
data_path: Path to input data
mode: Execution mode (spark, jvm, native)
"""
self.spark = spark
self.data_path = data_path
self.mode = mode

@classmethod
@abstractmethod
def name(cls) -> str:
"""Return the benchmark name (used for CLI)."""
pass

@classmethod
@abstractmethod
def description(cls) -> str:
"""Return a short description of the benchmark."""
pass

@abstractmethod
def run(self) -> Dict[str, Any]:
"""
Run the benchmark and return results.

Returns:
Dictionary containing benchmark results (must include 'duration_ms')
"""
pass

    def execute_timed(self) -> Dict[str, Any]:
        """
        Execute the benchmark with timing and standard output.

        Returns:
            Dictionary containing benchmark results
        """
        print(f"\n{'=' * 80}")
        print(f"Benchmark: {self.name()}")
        print(f"Mode: {self.mode.upper()}")
        print(f"{'=' * 80}")
        print(f"Data path: {self.data_path}")

        # Print relevant Spark configuration
        self._print_spark_config()

        # Clear cache before running
        self.spark.catalog.clearCache()

        # Run the benchmark
        print("\nRunning benchmark...")
        results = self.run()

        # Print results
        print(f"\nDuration: {results['duration_ms']:,} ms")
        if 'row_count' in results:
            print(f"Rows processed: {results['row_count']:,}")

        # Print any additional metrics
        for key, value in results.items():
            if key not in ['duration_ms', 'row_count']:
                print(f"{key}: {value}")

        print(f"{'=' * 80}\n")

        return results

    def _print_spark_config(self):
        """Print relevant Spark configuration."""
        conf = self.spark.sparkContext.getConf()
        print(f"Shuffle manager: {conf.get('spark.shuffle.manager', 'default')}")
        print(f"Comet enabled: {conf.get('spark.comet.enabled', 'false')}")
        print(f"Comet shuffle enabled: {conf.get('spark.comet.exec.shuffle.enabled', 'false')}")
        print(f"Comet shuffle mode: {conf.get('spark.comet.exec.shuffle.mode', 'not set')}")
        print(f"Spark UI: {self.spark.sparkContext.uiWebUrl}")

    def _time_operation(self, operation_fn):
        """
        Time an operation and return duration in milliseconds.

        Args:
            operation_fn: Function to time (takes no arguments)

        Returns:
            Duration in milliseconds
        """
        start_time = time.time()
        operation_fn()
        duration_ms = int((time.time() - start_time) * 1000)
        return duration_ms