Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {

// Extract object store options from first file (S3 configs apply to all files in scan)
var firstPartition: Option[PartitionedFile] = None
val filePartitions = scan.getFilePartitions()
val filePartitions = scan.planner.getFilePartitions()
firstPartition = filePartitions.flatMap(_.files.headOption).headOption

val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
Expand Down Expand Up @@ -205,6 +205,6 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging {
}

/**
 * Builds the [[CometNativeExec]] for `op` by calling the `CometNativeScanExec` factory with the
 * serialized native operator, the wrapped Spark scan and the session.
 *
 * NOTE(review): the two factory calls below look like the before/after versions of a single
 * line in this diff view (passing `op` vs. `op.planner`); as written, only the last expression
 * is the method's result. Confirm against the full file.
 */
override def createExec(nativeOp: Operator, op: CometScanExec): CometNativeExec = {
CometNativeScanExec(nativeOp, op.wrapped, op.session, op)
CometNativeScanExec(nativeOp, op.wrapped, op.session, op.planner)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ case class CometNativeScanExec(
disableBucketedScan: Boolean = false,
originalPlan: FileSourceScanExec,
override val serializedPlanOpt: SerializedPlan,
@transient scan: CometScanExec, // Lazy access to file partitions without serializing with plan
@transient planner: FilePartitionPlanner,
sourceKey: String) // Key for PlanDataInjector to match common+partition data at runtime
extends CometLeafExec
with DataSourceScanExec
Expand Down Expand Up @@ -143,8 +143,9 @@ case class CometNativeScanExec(
// Extract common data from nativeOp
val commonBytes = nativeOp.getNativeScan.getCommon.toByteArray

// Get file partitions from CometScanExec (handles bucketing, etc.)
val filePartitions = scan.getFilePartitions()
// Get file partitions from FilePartitionPlanner (handles bucketing, etc.)
val filePartitions = planner.getFilePartitions()
planner.sendDriverMetrics(metrics, sparkContext)

// Serialize each partition's files
import org.apache.comet.serde.operator.partition2Proto
Expand Down Expand Up @@ -210,7 +211,7 @@ case class CometNativeScanExec(
disableBucketedScan,
originalPlan.doCanonicalize(),
SerializedPlan(None),
null, // Transient scan not needed for canonicalization
null, // Transient planner not needed for canonicalization
""
) // sourceKey not needed for canonicalization
}
Expand Down Expand Up @@ -246,7 +247,7 @@ case class CometNativeScanExec(
case Some(metric) => nativeMetrics + ("numOutputRows" -> metric)
case None => nativeMetrics
}
withAlias ++ scan.metrics.filterKeys(driverMetricKeys)
withAlias ++ originalPlan.driverMetrics.filterKeys(driverMetricKeys)
}

/**
Expand All @@ -260,7 +261,7 @@ object CometNativeScanExec {
nativeOp: Operator,
scanExec: FileSourceScanExec,
session: SparkSession,
scan: CometScanExec): CometNativeScanExec = {
planner: FilePartitionPlanner): CometNativeScanExec = {
// TreeNode.mapProductIterator is protected method.
def mapProductIterator[B: ClassTag](product: Product, f: Any => B): Array[B] = {
val arr = Array.ofDim[B](product.productArity)
Expand Down Expand Up @@ -310,7 +311,7 @@ object CometNativeScanExec {
wrapped.disableBucketedScan,
wrapped,
SerializedPlan(None),
scan,
planner,
sourceKey)
scanExec.logicalLink.foreach(batchScanExec.setLogicalLink)
batchScanExec
Expand Down
273 changes: 15 additions & 258 deletions spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,16 @@

package org.apache.spark.sql.comet

import scala.collection.mutable.HashMap
import scala.concurrent.duration.NANOSECONDS
import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.comet.shims.ShimCometScanExec
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
Expand Down Expand Up @@ -83,62 +78,20 @@ case class CometScanExec(

/** Reports the vector types advertised by the wrapped scan. */
override def vectorTypes: Option[Seq[String]] = {
  wrapped.vectorTypes
}

// Driver-side metric values, keyed by metric name, collected while planning the scan.
private lazy val driverMetrics: HashMap[String, Long] = new HashMap[String, Long]()

/**
 * Publishes the driver-side metrics: each recorded value is added to the SQLMetric of the same
 * name, and the touched metrics are then posted for the current SQL execution id.
 *
 * Callers are expected to have initialized selectedPartitions first; see SPARK-26327.
 */
private def sendDriverMetrics(): Unit = {
  driverMetrics.foreach { case (name, value) =>
    metrics(name).add(value)
  }
  val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
  // Only the metrics that have a driver-side value are reported.
  val updatedMetrics = metrics.collect {
    case (name, metric) if driverMetrics.contains(name) => metric
  }.toSeq
  SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, updatedMetrics)
}
// Created on first access and marked @transient. It receives this scan's relation, required
// schema, partition/data filters and bucketing settings.
@transient lazy val planner: FilePartitionPlanner = {
  new FilePartitionPlanner(
    relation, requiredSchema, partitionFilters, dataFilters,
    optionalBucketSet, optionalNumCoalescedBuckets, bucketedScan)
}

/** Returns true when `e.find` locates a [[PlanExpression]] node in the expression. */
private def isDynamicPruningFilter(e: Expression): Boolean = {
  val firstPlanExpression = e.find {
    case _: PlanExpression[_] => true
    case _ => false
  }
  firstPlanExpression.nonEmpty
}

// Partitions obtained by listing files with every partition filter that is not a dynamic
// pruning filter, plus the data filters. Also records the file metrics and "metadataTime".
@transient lazy val selectedPartitions: Array[PartitionDirectory] = {
  val fileIndex = relation.location
  val optimizerMetadataTimeNs = fileIndex.metadataOpsTimeNs.getOrElse(0L)
  val listingStartNs = System.nanoTime()
  val staticFilters = partitionFilters.filterNot(isDynamicPruningFilter)
  val listed = fileIndex.listFiles(staticFilters, dataFilters)
  setFilesNumAndSizeMetric(listed, static = true)
  val elapsedNs = System.nanoTime() - listingStartNs
  // Reported time covers this listing plus the metadata time already spent by the file index.
  driverMetrics("metadataTime") = NANOSECONDS.toMillis(elapsedNs + optimizerMetadataTimeNs)
  listed.toArray
}

// When a dynamic partition filter is present, the actual partitions are only known at runtime,
// because such a filter depends on values that become available during execution (for example
// the keys from the other side of a join). Without dynamic filters this is selectedPartitions.
@transient private lazy val dynamicallySelectedPartitions: Array[PartitionDirectory] = {
  val dynamicPartitionFilters = partitionFilters.filter(isDynamicPruningFilter)

  if (dynamicPartitionFilters.isEmpty) {
    selectedPartitions
  } else {
    val pruningStartNs = System.nanoTime()
    val partitionColumns = relation.partitionSchema
    // AND all dynamic filters together, then rewrite attribute references into positional
    // references over the partition schema so the predicate can run on partition values.
    val combinedFilter = dynamicPartitionFilters.reduce(And)
    val boundFilter = combinedFilter.transform { case attr: AttributeReference =>
      val ordinal = partitionColumns.indexWhere(_.name == attr.name)
      BoundReference(ordinal, partitionColumns(ordinal).dataType, nullable = true)
    }
    val boundPredicate = Predicate.create(boundFilter, Nil)
    val pruned = selectedPartitions.filter(dir => boundPredicate.eval(dir.values))
    setFilesNumAndSizeMetric(pruned, static = false)
    val elapsedNs = System.nanoTime() - pruningStartNs
    driverMetrics("pruningTime") = elapsedNs / 1000000L
    pruned
  }
}
// Partition directories selected for this scan, as reported by `planner`.
@transient lazy val selectedPartitions: Array[PartitionDirectory] = {
  planner.selectedPartitions
}

// Mirrors the wrapped scan's bucketedScan flag; left non-private so tests can read it.
lazy val bucketedScan: Boolean = {
  wrapped.bucketedScan
}
Expand Down Expand Up @@ -214,41 +167,14 @@ case class CometScanExec(
hadoopConf =
relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))

val readRDD = if (bucketedScan) {
createBucketedReadRDD(
relation.bucketSpec.get,
readFile,
dynamicallySelectedPartitions,
relation)
} else {
createReadRDD(readFile, dynamicallySelectedPartitions, relation)
}
sendDriverMetrics()
readRDD
val filePartitions = getFilePartitions()
prepareRDD(relation, readFile, filePartitions)
}

/** Returns this scan's single input RDD wrapped in a one-element list. */
override def inputRDDs(): Seq[RDD[InternalRow]] = List(inputRDD)

/**
 * Records the total file count and total file size of `partitions` as driver metrics.
 *
 * The totals are stored under "staticFilesNum"/"staticFilesSize" when `static` is true and at
 * least one partition filter is a dynamic pruning filter; otherwise they are stored under
 * "numFiles"/"filesSize". "numPartitions" is recorded only when the relation has partition
 * columns.
 */
private def setFilesNumAndSizeMetric(
    partitions: Seq[PartitionDirectory],
    static: Boolean): Unit = {
  val filesNum = partitions.foldLeft(0L)((total, dir) => total + dir.files.size)
  val filesSize = partitions.foldLeft(0L) { (total, dir) =>
    total + dir.files.foldLeft(0L)(_ + _.getLen)
  }
  val recordAsStatic = static && partitionFilters.exists(isDynamicPruningFilter)
  if (recordAsStatic) {
    driverMetrics("staticFilesNum") = filesNum
    driverMetrics("staticFilesSize") = filesSize
  } else {
    driverMetrics("numFiles") = filesNum
    driverMetrics("filesSize") = filesSize
  }
  if (relation.partitionSchema.nonEmpty) {
    driverMetrics("numPartitions") = partitions.length
  }
}

override lazy val metrics: Map[String, SQLMetric] =
wrapped.driverMetrics ++ CometMetricNode.baseScanMetrics(
session.sparkContext) ++ (relation.fileFormat match {
Expand Down Expand Up @@ -296,178 +222,9 @@ case class CometScanExec(
* for native scans that only need partition metadata.
*/
def getFilePartitions(): Seq[FilePartition] = {
  // Pick the planning strategy from the bucketing flag; both paths work on the partitions
  // that remain after dynamic pruning.
  val filePartitions =
    if (!bucketedScan) {
      createFilePartitionsForNonBucketedScan(dynamicallySelectedPartitions, relation)
    } else {
      createFilePartitionsForBucketedScan(
        relation.bucketSpec.get,
        dynamicallySelectedPartitions,
        relation)
    }
  // Publish the driver metrics only after the partitions have been computed.
  sendDriverMetrics()
  filePartitions
}

/**
 * Plans the file partitions of a bucketed scan without instantiating readers. Files are grouped
 * by the bucket id parsed from their file name, optionally pruned by `optionalBucketSet`, and
 * optionally coalesced into `optionalNumCoalescedBuckets` buckets.
 *
 * @param bucketSpec
 *   the bucketing spec.
 * @param selectedPartitions
 *   Hive-style partitions that are part of the read.
 * @param fsRelation
 *   [[HadoopFsRelation]] associated with the read.
 * @return
 *   one [[FilePartition]] per (possibly coalesced) bucket id; buckets without files get an empty
 *   partition.
 */
private def createFilePartitionsForBucketedScan(
    bucketSpec: BucketSpec,
    selectedPartitions: Array[PartitionDirectory],
    fsRelation: HadoopFsRelation): Seq[FilePartition] = {
  logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")

  val allFiles =
    for {
      partition <- selectedPartitions
      file <- partition.files
    } yield getPartitionedFile(file, partition)

  // A file whose name carries no bucket id is rejected as an invalid bucket file.
  val filesGroupedToBuckets = allFiles.groupBy { file =>
    val fileName = new Path(file.filePath.toString()).getName
    BucketingUtils
      .getBucketId(fileName)
      .getOrElse(throw QueryExecutionErrors.invalidBucketFile(file.filePath.toString()))
  }

  val prunedFilesGroupedToBuckets = optionalBucketSet match {
    case Some(bucketSet) =>
      filesGroupedToBuckets.filter { case (bucketId, _) => bucketSet.get(bucketId) }
    case None =>
      filesGroupedToBuckets
  }

  optionalNumCoalescedBuckets match {
    case Some(numCoalescedBuckets) =>
      logInfo(s"Coalescing to ${numCoalescedBuckets} buckets")
      val coalescedBuckets = prunedFilesGroupedToBuckets.groupBy(_._1 % numCoalescedBuckets)
      Seq.tabulate(numCoalescedBuckets) { bucketId =>
        val partitionedFiles = coalescedBuckets.get(bucketId) match {
          case Some(buckets) => buckets.values.flatten.toArray
          case None => Array.empty[PartitionedFile]
        }
        FilePartition(bucketId, partitionedFiles)
      }
    case None =>
      Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
        FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty))
      }
  }
}

/**
 * Plans the file partitions of a non-bucketed scan without instantiating readers. Files are
 * optionally pruned by bucket id, split, ordered by descending split length and then handed to
 * `FilePartition.getFilePartitions`.
 *
 * @param selectedPartitions
 *   Hive-style partitions that are part of the read.
 * @param fsRelation
 *   [[HadoopFsRelation]] associated with the read.
 * @return
 *   the file partitions produced by `FilePartition.getFilePartitions`.
 */
private def createFilePartitionsForNonBucketedScan(
    selectedPartitions: Array[PartitionDirectory],
    fsRelation: HadoopFsRelation): Seq[FilePartition] = {
  val sessionConf = fsRelation.sparkSession.sessionState.conf
  val openCostInBytes = sessionConf.filesOpenCostInBytes
  val maxSplitBytes =
    FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions)
  logInfo(
    s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
      s"open cost is considered as scanning $openCostInBytes bytes.")

  // Filter files with bucket pruning if possible
  val bucketingEnabled = sessionConf.bucketingEnabled
  val shouldProcess: Path => Boolean = optionalBucketSet match {
    case Some(bucketSet) if bucketingEnabled =>
      // Do not prune the file if bucket file name is invalid
      path => BucketingUtils.getBucketId(path.getName).forall(bucketSet.get)
    case _ =>
      _ => true
  }

  // Splits every file of one partition directory that survives bucket pruning.
  def splitsOf(partition: PartitionDirectory) = partition.files.flatMap { file =>
    // getPath() is very expensive so we only want to call it once in this block:
    val filePath = file.getPath

    if (!shouldProcess(filePath)) {
      Seq.empty
    } else {
      val formatAllowsSplit =
        relation.fileFormat.isSplitable(relation.sparkSession, relation.options, filePath)
      // SPARK-39634: Allow file splitting in combination with row index generation once
      // the fix for PARQUET-2161 is available.
      val isSplitable = formatAllowsSplit && !isNeededForSchema(requiredSchema)
      super.splitFiles(
        sparkSession = relation.sparkSession,
        file = file,
        filePath = filePath,
        isSplitable = isSplitable,
        maxSplitBytes = maxSplitBytes,
        partitionValues = partition.values)
    }
  }

  val largestFirst = Ordering[Long].reverse
  val sortedSplits = selectedPartitions
    .flatMap(partition => splitsOf(partition))
    .sortBy(_.length)(largestFirst)

  FilePartition.getFilePartitions(relation.sparkSession, sortedSplits, maxSplitBytes)
}

/**
 * Builds the RDD for a bucketed read; [[createReadRDD]] is the non-bucketed counterpart.
 *
 * Each RDD partition being returned should include all the files with the same bucket id from
 * all the given Hive partitions.
 *
 * @param bucketSpec
 *   the bucketing spec.
 * @param readFile
 *   a function to read each (part of a) file.
 * @param selectedPartitions
 *   Hive-style partitions that are part of the read.
 * @param fsRelation
 *   [[HadoopFsRelation]] associated with the read.
 * @return
 *   the RDD produced by `prepareRDD` over the bucketed file partitions.
 */
private def createBucketedReadRDD(
    bucketSpec: BucketSpec,
    readFile: (PartitionedFile) => Iterator[InternalRow],
    selectedPartitions: Array[PartitionDirectory],
    fsRelation: HadoopFsRelation): RDD[InternalRow] =
  prepareRDD(
    fsRelation,
    readFile,
    createFilePartitionsForBucketedScan(bucketSpec, selectedPartitions, fsRelation))

/**
* Create an RDD for non-bucketed reads. The bucketed variant of this function is
* [[createBucketedReadRDD]].
*
* @param readFile
* a function to read each (part of a) file.
* @param selectedPartitions
* Hive-style partitions that are part of the read.
* @param fsRelation
* [[HadoopFsRelation]] associated with the read.
*/
private def createReadRDD(
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Array[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
// Plan the file partitions first, then hand them to prepareRDD together with the reader.
val filePartitions = createFilePartitionsForNonBucketedScan(selectedPartitions, fsRelation)
prepareRDD(fsRelation, readFile, filePartitions)
// NOTE(review): the next three lines return planner.getFilePartitions() rather than an RDD;
// in this diff view they look like the added body of getFilePartitions() interleaved with this
// method. Confirm against the full file.
val result = planner.getFilePartitions()
planner.sendDriverMetrics(metrics, sparkContext)
result
}

private def prepareRDD(
Expand Down
Loading
Loading