Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions datafusion-flight-sql-server/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,21 @@ impl ArrowFlightSqlService for FlightSqlService {
sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => {
// print!("Query: {query}\n");

let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?;
let arrow_schema = stream.schema();
// Declare the logical-plan schema (the same schema that
// GetFlightInfo advertises) rather than the physical stream
// schema. The two can differ in field nullability (e.g. for
// aggregates), and strict clients such as the ADBC Flight SQL
// driver reject a DoGet stream whose schema does not match
// the advertised FlightInfo schema exactly.
let plan = ctx
.sql_to_logical_plan(&query)
.await
.map_err(df_error_to_status)?;
let arrow_schema = get_schema_for_plan(&plan, self.config.schema_with_metadata);
let stream = ctx
.execute_logical_plan(plan)
.await
.map_err(df_error_to_status)?;
let arrow_stream = stream.map(|i| {
let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
Ok(batch)
Expand Down Expand Up @@ -253,11 +266,14 @@ impl ArrowFlightSqlService for FlightSqlService {
.map_err(df_error_to_status)?;
}

// Same schema consistency requirement as for
// CommandStatementQuery above.
let arrow_schema =
get_schema_for_plan(&plan, self.config.schema_with_metadata);
let stream = ctx
.execute_logical_plan(plan)
.await
.map_err(df_error_to_status)?;
let arrow_schema = stream.schema();
let arrow_stream = stream.map(|i| {
let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
Ok(batch)
Expand All @@ -283,11 +299,14 @@ impl ArrowFlightSqlService for FlightSqlService {

let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?;

// Same schema consistency requirement as for
// CommandStatementQuery above.
let arrow_schema =
get_schema_for_plan(&plan, self.config.schema_with_metadata);
let state = ctx.inner.state();
let df = DataFrame::new(state, plan);

let stream = df.execute_stream().await.map_err(df_error_to_status)?;
let arrow_schema = stream.schema();
let arrow_stream = stream.map(|i| {
let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
Ok(batch)
Expand Down
89 changes: 89 additions & 0 deletions datafusion-flight-sql-server/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,92 @@ async fn test_query_with_join() {
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(total_rows, 4, "Should have 4 rows from join");
}

/// Verifies that the schema declared by the DoGet stream is exactly the
/// schema advertised in the FlightInfo returned for the same query.
///
/// Aggregates over statistics-backed sources (e.g. Parquet) are a case
/// where the logical plan schema (advertised in FlightInfo) and the
/// physical stream schema disagree on field nullability: the optimizer
/// rewrites MIN() into a non-nullable literal taken from the file
/// statistics, while the logical schema keeps MIN() nullable. Strict
/// clients (e.g. the ADBC Flight SQL driver) reject the DoGet stream
/// if its schema does not exactly match the one advertised in
/// FlightInfo.
#[tokio::test]
async fn test_do_get_schema_matches_advertised_flight_info_schema() {
    use datafusion::parquet::arrow::ArrowWriter;
    use datafusion::prelude::ParquetReadOptions;

    /// Removes the wrapped file when dropped, so the temporary Parquet file
    /// is cleaned up even if an assertion or `expect` below panics.
    struct RemoveFileOnDrop(std::path::PathBuf);

    impl Drop for RemoveFileOnDrop {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure to delete must not mask the
            // real test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let schema = Arc::new(Schema::new(vec![Field::new(
        "amount",
        DataType::Int32,
        false,
    )]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![50, 75, 100, 25]))],
    )
    .unwrap();

    // The process id keeps the file name unique across concurrently running
    // test binaries that share the same temp directory.
    let path = std::env::temp_dir().join(format!(
        "flight_sql_schema_test_{}.parquet",
        std::process::id()
    ));
    let file = std::fs::File::create(&path).expect("create parquet file");
    // Armed right after creation so every later failure path still cleans up.
    // Declared before the context/server/client so it is dropped after them.
    let _cleanup = RemoveFileOnDrop(path.clone());

    let mut writer = ArrowWriter::try_new(file, schema, None).expect("create writer");
    writer.write(&batch).expect("write batch");
    writer.close().expect("close writer");

    let ctx = SessionContext::new();
    ctx.register_parquet(
        "orders_pq",
        path.to_str().expect("utf-8 path"),
        ParquetReadOptions::default(),
    )
    .await
    .expect("register parquet");

    let addr = "0.0.0.0:50069";
    start_test_server(addr.to_string(), ctx.state()).await;

    let mut client = create_test_client(&format!("http://{}", addr)).await;

    let flight_info = client
        .execute(
            "SELECT MIN(amount) AS lo, COUNT(*) AS n FROM orders_pq".to_string(),
            None,
        )
        .await
        .expect("Query should succeed");

    // Pull the ticket out first: decoding the schema consumes the
    // FlightInfo, so doing it last avoids cloning the whole message.
    let ticket = flight_info
        .endpoint
        .first()
        .expect("Should have endpoint")
        .ticket
        .clone()
        .expect("Should have ticket");

    let advertised = flight_info
        .try_decode_schema()
        .expect("FlightInfo should carry a schema");

    // Drain the stream completely; the record batches themselves are not
    // inspected, only the schema the stream declared.
    let mut stream = client.do_get(ticket).await.expect("do_get should succeed");
    while stream
        .try_next()
        .await
        .expect("Stream should work")
        .is_some()
    {}

    let actual = stream
        .schema()
        .expect("DoGet stream should declare a schema")
        .clone();

    assert_eq!(
        advertised,
        *actual,
        "DoGet stream schema must match the schema advertised in FlightInfo"
    );
}
Loading