diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ac406a9fa694..4773486cb5cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -23,7 +23,9 @@ import java.util.regex.Pattern import scala.jdk.CollectionConverters._ import org.scalatest.Assertions +import org.scalatest.Suite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -35,7 +37,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ -abstract class QueryTest extends PlanTest with SparkSessionProvider { +trait QueryTestBase extends PlanTestBase with SparkSessionProvider { self: Suite => /** * Runs the plan and makes sure the answer contains all of the keywords. @@ -202,7 +204,8 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider { * Asserts that a given [[Dataset]] will be executed using the given number of cached results. */ def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { - val planWithCaching = query.queryExecution.withCachedData + val planWithCaching = + query.asInstanceOf[classic.Dataset[_]].queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached } @@ -218,7 +221,8 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider { * storage level. */ def assertCached(query: Dataset[_], cachedName: String, storageLevel: StorageLevel): Unit = { - val planWithCaching = query.queryExecution.withCachedData + val planWithCaching = + query.asInstanceOf[classic.Dataset[_]].queryExecution.withCachedData val matched = planWithCaching.exists { case cached: InMemoryRelation => val cacheBuilder = cached.cacheBuilder @@ -238,14 +242,19 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider { * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. */ def assertEmptyMissingInput(query: Dataset[_]): Unit = { - assert(query.queryExecution.analyzed.missingInput.isEmpty, - s"The analyzed logical plan has missing inputs:\n${query.queryExecution.analyzed}") - assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, - s"The optimized logical plan has missing inputs:\n${query.queryExecution.optimizedPlan}") - assert(query.queryExecution.executedPlan.missingInput.isEmpty, - s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") + val qe = query.asInstanceOf[classic.Dataset[_]].queryExecution + assert(qe.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs:\n${qe.analyzed}") + assert(qe.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs:\n${qe.optimizedPlan}") + assert(qe.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs:\n${qe.executedPlan}") } +} + +abstract class QueryTest extends SparkFunSuite with QueryTestBase { + protected def getCurrentClassCallSitePattern: String = { val cs = Thread.currentThread().getStackTrace()(2) s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)" @@ -434,7 +443,7 @@ object QueryTest extends Assertions { * @param expectedAnswer the expected result in a[[Row]]. * @param absTol the absolute tolerance between actual and expected answers. */ - protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double): Unit = { require(actualAnswer.length == expectedAnswer.length, s"actual answer length ${actualAnswer.length} != " + s"expected answer length ${expectedAnswer.length}") @@ -469,13 +478,14 @@ object QueryTest extends Assertions { } } - spark.sparkContext.listenerBus.waitUntilEmpty(15000) - spark.listenerManager.register(listener) + val classicSession = spark.asInstanceOf[classic.SparkSession] + classicSession.sparkContext.listenerBus.waitUntilEmpty(15000) + classicSession.listenerManager.register(listener) try { thunk - spark.sparkContext.listenerBus.waitUntilEmpty(15000) + classicSession.sparkContext.listenerBus.waitUntilEmpty(15000) } finally { - spark.listenerManager.unregister(listener) + classicSession.listenerManager.unregister(listener) } capturedQueryExecutions