38 changes: 24 additions & 14 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -23,7 +23,9 @@ import java.util.regex.Pattern
 import scala.jdk.CollectionConverters._
 
 import org.scalatest.Assertions
+import org.scalatest.Suite
 
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
@@ -35,7 +37,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.ArrayImplicits._
 
 
-abstract class QueryTest extends PlanTest with SparkSessionProvider {
+trait QueryTestBase extends PlanTestBase with SparkSessionProvider { self: Suite =>
 
   /**
    * Runs the plan and makes sure the answer contains all of the keywords.
@@ -202,7 +204,8 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider {
    * Asserts that a given [[Dataset]] will be executed using the given number of cached results.
    */
   def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
-    val planWithCaching = query.queryExecution.withCachedData
+    val planWithCaching =
+      query.asInstanceOf[classic.Dataset[_]].queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached
     }
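The recurring change in this diff is a downcast from the backend-agnostic `Dataset[_]` to `classic.Dataset[_]` before touching `queryExecution`, which only the classic (JVM) implementation exposes. A minimal sketch of the pattern, with a hypothetical helper name that is not part of this change:

```scala
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.classic

// Hypothetical helper illustrating the downcast: queryExecution lives on
// the classic Dataset implementation, not on the shared Dataset API.
def analyzedPlanOf(ds: Dataset[Row]): String =
  ds.asInstanceOf[classic.Dataset[Row]].queryExecution.analyzed.toString
```

The same cast appears in the second `assertCached` overload and in `assertEmptyMissingInput` below.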
@@ -218,7 +221,8 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider {
    * storage level.
    */
   def assertCached(query: Dataset[_], cachedName: String, storageLevel: StorageLevel): Unit = {
-    val planWithCaching = query.queryExecution.withCachedData
+    val planWithCaching =
+      query.asInstanceOf[classic.Dataset[_]].queryExecution.withCachedData
     val matched = planWithCaching.exists {
       case cached: InMemoryRelation =>
         val cacheBuilder = cached.cacheBuilder
@@ -238,14 +242,19 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider {
    * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
    */
   def assertEmptyMissingInput(query: Dataset[_]): Unit = {
-    assert(query.queryExecution.analyzed.missingInput.isEmpty,
-      s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}")
-    assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
-      s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}")
-    assert(query.queryExecution.executedPlan.missingInput.isEmpty,
-      s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
+    val qe = query.asInstanceOf[classic.Dataset[_]].queryExecution
+    assert(qe.analyzed.missingInput.isEmpty,
+      s"The analyzed logical plan has missing inputs:\n${qe.analyzed}")
+    assert(qe.optimizedPlan.missingInput.isEmpty,
+      s"The optimized logical plan has missing inputs:\n${qe.optimizedPlan}")
+    assert(qe.executedPlan.missingInput.isEmpty,
+      s"The physical plan has missing inputs:\n${qe.executedPlan}")
   }
 
+}
+
+abstract class QueryTest extends SparkFunSuite with QueryTestBase {
+
   protected def getCurrentClassCallSitePattern: String = {
     val cs = Thread.currentThread().getStackTrace()(2)
     s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)"
@@ -434,7 +443,7 @@ object QueryTest extends Assertions {
    * @param expectedAnswer the expected result in a[[Row]].
    * @param absTol the absolute tolerance between actual and expected answers.
    */
-  protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = {
+  def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double): Unit = {
     require(actualAnswer.length == expectedAnswer.length,
       s"actual answer length ${actualAnswer.length} != " +
       s"expected answer length ${expectedAnswer.length}")
@@ -469,13 +478,14 @@
       }
     }
 
-    spark.sparkContext.listenerBus.waitUntilEmpty(15000)
-    spark.listenerManager.register(listener)
+    val classicSession = spark.asInstanceOf[classic.SparkSession]
+    classicSession.sparkContext.listenerBus.waitUntilEmpty(15000)
+    classicSession.listenerManager.register(listener)
     try {
       thunk
-      spark.sparkContext.listenerBus.waitUntilEmpty(15000)
+      classicSession.sparkContext.listenerBus.waitUntilEmpty(15000)
     } finally {
-      spark.listenerManager.unregister(listener)
+      classicSession.listenerManager.unregister(listener)
     }
 
     capturedQueryExecutions
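The last hunk applies the same idea to the session: `sparkContext` and `listenerManager` are members of the classic session, so the generic handle is downcast once and reused. Distilled into a standalone helper for illustration (note that `listenerBus` is `private[spark]`, so this sketch only compiles inside Spark's own source tree, just as the original test code does):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.classic
import org.apache.spark.sql.util.QueryExecutionListener

// Sketch of the register/run/unregister lifecycle from the diff: drain
// the listener bus, register, run the body, drain again, then always
// unregister so one test's listener cannot leak into the next.
def withListener(spark: SparkSession, listener: QueryExecutionListener)(thunk: => Unit): Unit = {
  val classicSession = spark.asInstanceOf[classic.SparkSession]
  classicSession.sparkContext.listenerBus.waitUntilEmpty(15000)
  classicSession.listenerManager.register(listener)
  try {
    thunk
    classicSession.sparkContext.listenerBus.waitUntilEmpty(15000)
  } finally {
    classicSession.listenerManager.unregister(listener)
  }
}
```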