diff --git a/datafusion-flight-sql-server/src/service.rs b/datafusion-flight-sql-server/src/service.rs index dbaebc4..24633c1 100644 --- a/datafusion-flight-sql-server/src/service.rs +++ b/datafusion-flight-sql-server/src/service.rs @@ -220,8 +220,21 @@ impl ArrowFlightSqlService for FlightSqlService { sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => { // print!("Query: {query}\n"); - let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); + // Declare the logical-plan schema (the same schema that + // GetFlightInfo advertises) rather than the physical stream + // schema. The two can differ in field nullability (e.g. for + // aggregates), and strict clients such as the ADBC Flight SQL + // driver reject a DoGet stream whose schema does not match + // the advertised FlightInfo schema exactly. + let plan = ctx + .sql_to_logical_plan(&query) + .await + .map_err(df_error_to_status)?; + let arrow_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata); + let stream = ctx + .execute_logical_plan(plan) + .await + .map_err(df_error_to_status)?; let arrow_stream = stream.map(|i| { let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; Ok(batch) @@ -253,11 +266,14 @@ impl ArrowFlightSqlService for FlightSqlService { .map_err(df_error_to_status)?; } + // Same schema consistency requirement as for + // CommandStatementQuery above. + let arrow_schema = + get_schema_for_plan(&plan, self.config.schema_with_metadata); let stream = ctx .execute_logical_plan(plan) .await .map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); let arrow_stream = stream.map(|i| { let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; Ok(batch) @@ -283,11 +299,14 @@ impl ArrowFlightSqlService for FlightSqlService { let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?; + // Same schema consistency requirement as for + // CommandStatementQuery above. + let arrow_schema = + get_schema_for_plan(&plan, self.config.schema_with_metadata); let state = ctx.inner.state(); let df = DataFrame::new(state, plan); let stream = df.execute_stream().await.map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); let arrow_stream = stream.map(|i| { let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; Ok(batch) diff --git a/datafusion-flight-sql-server/tests/integration_test.rs b/datafusion-flight-sql-server/tests/integration_test.rs index 72b9f00..1870d5f 100644 --- a/datafusion-flight-sql-server/tests/integration_test.rs +++ b/datafusion-flight-sql-server/tests/integration_test.rs @@ -326,3 +326,92 @@ async fn test_query_with_join() { let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!(total_rows, 4, "Should have 4 rows from join"); } + +#[tokio::test] +async fn test_do_get_schema_matches_advertised_flight_info_schema() { + use datafusion::parquet::arrow::ArrowWriter; + use datafusion::prelude::ParquetReadOptions; + + // Aggregates over statistics-backed sources (e.g. Parquet) are a case + // where the logical plan schema (advertised in FlightInfo) and the + // physical stream schema disagree on field nullability: the optimizer + // rewrites MIN() into a non-nullable literal taken from the file + // statistics, while the logical schema keeps MIN() nullable. Strict + // clients (e.g. the ADBC Flight SQL driver) reject the DoGet stream + // if its schema does not exactly match the one advertised in + // FlightInfo. + let schema = Arc::new(Schema::new(vec![Field::new( + "amount", + DataType::Int32, + false, + )])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![50, 75, 100, 25]))], + ) + .unwrap(); + let path = std::env::temp_dir().join(format!( + "flight_sql_schema_test_{}.parquet", + std::process::id() + )); + let file = std::fs::File::create(&path).expect("create parquet file"); + let mut writer = ArrowWriter::try_new(file, schema, None).expect("create writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let ctx = SessionContext::new(); + ctx.register_parquet( + "orders_pq", + path.to_str().expect("utf-8 path"), + ParquetReadOptions::default(), + ) + .await + .expect("register parquet"); + + let addr = "0.0.0.0:50069"; + start_test_server(addr.to_string(), ctx.state()).await; + + let mut client = create_test_client(&format!("http://{}", addr)).await; + + let flight_info = client + .execute( + "SELECT MIN(amount) AS lo, COUNT(*) AS n FROM orders_pq".to_string(), + None, + ) + .await + .expect("Query should succeed"); + + let advertised = flight_info + .clone() + .try_decode_schema() + .expect("FlightInfo should carry a schema"); + + let ticket = flight_info + .endpoint + .first() + .expect("Should have endpoint") + .ticket + .clone() + .expect("Should have ticket"); + + let mut stream = client.do_get(ticket).await.expect("do_get should succeed"); + while stream + .try_next() + .await + .expect("Stream should work") + .is_some() + {} + + let actual = stream + .schema() + .expect("DoGet stream should declare a schema") + .clone(); + + std::fs::remove_file(&path).ok(); + + assert_eq!( + advertised, + *actual, + "DoGet stream schema must match the schema advertised in FlightInfo" + ); +}