 */
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.GlutenSQLTestsTrait
+import org.apache.gluten.execution.FilterExecTransformerBase
 
-class GlutenPythonDataSourceSuite extends PythonDataSourceSuite with GlutenSQLTestsTrait {}
+import org.apache.spark.sql.{GlutenSQLTestsTrait, IntegratedUDFTestUtils, Row}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.python.PythonScan
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+class GlutenPythonDataSourceSuite extends PythonDataSourceSuite with GlutenSQLTestsTrait {
+
+  import IntegratedUDFTestUtils._
+
+  // Gluten replaces FilterExec with FilterExecTransformer and
+  // BatchScanExec with BatchScanExecTransformer.
+  testGluten("data source reader with filter pushdown") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource,
+         |    DataSourceReader,
+         |    EqualTo,
+         |    InputPartition,
+         |)
+         |
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def partitions(self):
+         |        return [InputPartition(i) for i in range(2)]
+         |
+         |    def pushFilters(self, filters):
+         |        for filter in filters:
+         |            if filter != EqualTo(("partition",), 0):
+         |                yield filter
+         |
+         |    def read(self, partition):
+         |        yield (0, partition.value)
+         |        yield (1, partition.value)
+         |        yield (2, partition.value)
+         |
+         |class SimpleDataSource(DataSource):
+         |    def schema(self):
+         |        return "id int, partition int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val schema = StructType.fromDDL("id INT, partition INT")
+    val dataSource =
+      createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript)
+    withSQLConf(SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      spark.dataSource.registerPython(dataSourceName, dataSource)
+      val df =
+        spark.read.format(dataSourceName).schema(schema).load().filter("id = 1 and partition = 0")
+      val plan = df.queryExecution.executedPlan
+
+      val filter = collectFirst(plan) {
+        case s: FilterExecTransformerBase =>
+          val condition = s.cond.toString
+          assert(!condition.contains("= 0"))
+          assert(condition.contains("= 1"))
+          s
+      }.getOrElse(
+        fail(s"FilterExecTransformerBase not found in the plan. Actual plan:\n$plan")
+      )
+
+      // Gluten does not replace PythonScan's BatchScanExec; it stays as a vanilla
+      // BatchScanExec with a RowToVeloxColumnar transition.
+      collectFirst(filter) {
+        case s: BatchScanExec if s.scan.isInstanceOf[PythonScan] =>
+          val p = s.scan.asInstanceOf[PythonScan]
+          assert(p.getMetaData().get("PushedFilters").contains("[EqualTo(partition,0)]"))
+      }.getOrElse(
+        fail(s"BatchScanExec with PythonScan not found. Actual plan:\n$plan")
+      )
+
+      checkAnswer(df, Seq(Row(1, 0), Row(1, 1)))
+    }
+  }
+}
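
Note: the scenario this test covers can also be reproduced from plain PySpark. The sketch below is illustrative only, not part of this patch; it assumes a Spark 4.x build where pyspark.sql.datasource provides DataSource, DataSourceReader, EqualTo and InputPartition (the same API the test script above relies on), and the "simple" short name and the literal conf key are assumptions to verify against your Spark version.

# sketch.py -- a minimal sketch, not part of this patch; assumes the
# Python Data Source API used by the test script above is available.
from pyspark.sql import SparkSession
from pyspark.sql.datasource import (
    DataSource,
    DataSourceReader,
    EqualTo,
    InputPartition,
)

class SimpleDataSourceReader(DataSourceReader):
    def partitions(self):
        return [InputPartition(i) for i in range(2)]

    def pushFilters(self, filters):
        # Claim EqualTo(partition, 0) by consuming it; yield the rest back
        # so Spark re-applies them on top of the scan.
        for f in filters:
            if f != EqualTo(("partition",), 0):
                yield f

    def read(self, partition):
        # Note: the claimed filter is deliberately not applied here.
        for i in range(3):
            yield (i, partition.value)

class SimpleDataSource(DataSource):
    @classmethod
    def name(cls):
        return "simple"  # assumed short name for spark.read.format(...)

    def schema(self):
        return "id int, partition int"

    def reader(self, schema):
        return SimpleDataSourceReader()

spark = SparkSession.builder.getOrCreate()
# Conf key per SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED used in the test above;
# verify the literal key against your Spark version.
spark.conf.set("spark.sql.python.filterPushdown.enabled", "true")
spark.dataSource.register(SimpleDataSource)
df = spark.read.format("simple").load().filter("id = 1 and partition = 0")
df.explain()  # with Gluten: FilterExecTransformer over a vanilla BatchScanExec
df.show()     # rows (1, 0) and (1, 1), matching checkAnswer in the test

Because the reader claims EqualTo(partition, 0) but never enforces it, Spark does not re-apply that predicate after the scan, so (1, 1) survives alongside (1, 0); that is exactly what the test's checkAnswer and PushedFilters metadata assertions verify.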