Skip to content

Commit bfda3d2

Browse files
committed
[substrait] Add support for ExtensionTable
1 parent 6cfd1cf commit bfda3d2

4 files changed

Lines changed: 218 additions & 38 deletions

File tree

datafusion/core/src/execution/context/mod.rs

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use datafusion_expr::{
6363
expr_rewriter::FunctionRewrite,
6464
logical_plan::{DdlStatement, Statement},
6565
planner::ExprPlanner,
66-
Expr, UserDefinedLogicalNode, WindowUDF,
66+
Expr, WindowUDF,
6767
};
6868

6969
// backwards compatibility
@@ -1679,27 +1679,7 @@ pub enum RegisterFunction {
16791679
#[derive(Debug)]
16801680
pub struct EmptySerializerRegistry;
16811681

1682-
impl SerializerRegistry for EmptySerializerRegistry {
1683-
fn serialize_logical_plan(
1684-
&self,
1685-
node: &dyn UserDefinedLogicalNode,
1686-
) -> Result<Vec<u8>> {
1687-
not_impl_err!(
1688-
"Serializing user defined logical plan node `{}` is not supported",
1689-
node.name()
1690-
)
1691-
}
1692-
1693-
fn deserialize_logical_plan(
1694-
&self,
1695-
name: &str,
1696-
_bytes: &[u8],
1697-
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
1698-
not_impl_err!(
1699-
"Deserializing user defined logical plan node `{name}` is not supported"
1700-
)
1701-
}
1702-
}
1682+
impl SerializerRegistry for EmptySerializerRegistry {}
17031683

17041684
/// Describes which SQL statements can be run.
17051685
///

datafusion/expr/src/registry.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use crate::expr_rewriter::FunctionRewrite;
2121
use crate::planner::ExprPlanner;
22-
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
22+
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
2323
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
2424
use std::collections::HashSet;
2525
use std::fmt::Debug;
@@ -123,22 +123,52 @@ pub trait FunctionRegistry {
123123
}
124124
}
125125

126-
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
126+
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
127+
/// and custom table providers for which the name alone is meaningless in the target
128+
/// execution context, e.g. UDTFs, manually registered tables etc.
127129
pub trait SerializerRegistry: Debug + Send + Sync {
128130
/// Serialize this node to a byte array. This serialization should not include
129131
/// input plans.
130132
fn serialize_logical_plan(
131133
&self,
132134
node: &dyn UserDefinedLogicalNode,
133-
) -> Result<Vec<u8>>;
135+
) -> Result<Vec<u8>> {
136+
not_impl_err!(
137+
"Serializing user defined logical plan node `{}` is not supported",
138+
node.name()
139+
)
140+
}
134141

135142
/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
136143
/// bytes.
137144
fn deserialize_logical_plan(
138145
&self,
139146
name: &str,
140-
bytes: &[u8],
141-
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
147+
_bytes: &[u8],
148+
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
149+
not_impl_err!(
150+
"Deserializing user defined logical plan node `{name}` is not supported"
151+
)
152+
}
153+
154+
/// Serialized table definition for UDTFs or manually registered table providers that can't be
155+
/// marshaled by reference. Should return some benign error for regular tables that can be
156+
/// found/restored by name in the destination execution context.
157+
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
158+
not_impl_err!("No custom table support")
159+
}
160+
161+
/// Deserialize the custom table with the given name.
162+
/// Note: more often than not, the name can't be used as a discriminator if multiple different
163+
/// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true
164+
/// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`).
165+
fn deserialize_custom_table(
166+
&self,
167+
name: &str,
168+
_bytes: &[u8],
169+
) -> Result<Arc<dyn TableSource>> {
170+
not_impl_err!("Deserializing custom table `{name}` is not supported")
171+
}
142172
}
143173

144174
/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
3030

3131
use datafusion::logical_expr::{
3232
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
33-
LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
33+
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
3434
};
3535
use substrait::proto::aggregate_rel::Grouping;
3636
use substrait::proto::expression as substrait_expression;
@@ -86,6 +86,7 @@ use substrait::proto::expression::{
8686
SingularOrList, SwitchExpression, WindowFunction,
8787
};
8888
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
89+
use substrait::proto::read_rel::ExtensionTable;
8990
use substrait::proto::rel_common::{Emit, EmitKind};
9091
use substrait::proto::set_rel::SetOp;
9192
use substrait::proto::{
@@ -438,6 +439,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
438439
user_defined_literal.type_reference
439440
)
440441
}
442+
443+
fn consume_extension_table(
444+
&self,
445+
extension_table: &ExtensionTable,
446+
_schema: &DFSchema,
447+
_projection: &Option<MaskExpression>,
448+
) -> Result<LogicalPlan> {
449+
if let Some(ext_detail) = extension_table.detail.as_ref() {
450+
substrait_err!(
451+
"Missing handler for extension table: {}",
452+
&ext_detail.type_url
453+
)
454+
} else {
455+
substrait_err!("Unexpected empty detail in ExtensionTable")
456+
}
457+
}
441458
}
442459

443460
/// Convert Substrait Rel to DataFusion DataFrame
@@ -559,6 +576,32 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
559576
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
560577
Ok(LogicalPlan::Extension(Extension { node: plan }))
561578
}
579+
580+
fn consume_extension_table(
581+
&self,
582+
extension_table: &ExtensionTable,
583+
schema: &DFSchema,
584+
projection: &Option<MaskExpression>,
585+
) -> Result<LogicalPlan> {
586+
if let Some(ext_detail) = &extension_table.detail {
587+
let source = self
588+
.state
589+
.serializer_registry()
590+
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
591+
let table_name = ext_detail
592+
.type_url
593+
.rsplit_once('/')
594+
.map(|(_, name)| name)
595+
.unwrap_or(&ext_detail.type_url);
596+
let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?;
597+
let plan = LogicalPlan::TableScan(table_scan);
598+
ensure_schema_compatibility(plan.schema(), schema.clone())?;
599+
let schema = apply_masking(schema.clone(), projection)?;
600+
apply_projection(plan, schema)
601+
} else {
602+
substrait_err!("Unexpected empty detail in ExtensionTable")
603+
}
604+
}
562605
}
563606

564607
// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
@@ -1449,8 +1492,11 @@ pub async fn from_read_rel(
14491492
)
14501493
.await
14511494
}
1452-
_ => {
1453-
not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
1495+
Some(ReadType::ExtensionTable(ext)) => {
1496+
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
1497+
}
1498+
None => {
1499+
substrait_err!("Unexpected empty read_type")
14541500
}
14551501
}
14561502
}

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 132 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use substrait::proto::expression::literal::{
6363
};
6464
use substrait::proto::expression::subquery::InPredicate;
6565
use substrait::proto::expression::window_function::BoundsType;
66-
use substrait::proto::read_rel::VirtualTable;
66+
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
6767
use substrait::proto::rel_common::EmitKind;
6868
use substrait::proto::rel_common::EmitKind::Emit;
6969
use substrait::proto::{
@@ -211,6 +211,23 @@ pub fn to_substrait_rel(
211211
let table_schema = scan.source.schema().to_dfschema_ref()?;
212212
let base_schema = to_substrait_named_struct(&table_schema)?;
213213

214+
let table = if let Ok(bytes) = state
215+
.serializer_registry()
216+
.serialize_custom_table(scan.source.as_ref())
217+
{
218+
ReadType::ExtensionTable(ExtensionTable {
219+
detail: Some(ProtoAny {
220+
type_url: scan.table_name.to_string(),
221+
value: bytes.into(),
222+
}),
223+
})
224+
} else {
225+
ReadType::NamedTable(NamedTable {
226+
names: scan.table_name.to_vec(),
227+
advanced_extension: None,
228+
})
229+
};
230+
214231
Ok(Box::new(Rel {
215232
rel_type: Some(RelType::Read(Box::new(ReadRel {
216233
common: None,
@@ -219,10 +236,7 @@ pub fn to_substrait_rel(
219236
best_effort_filter: None,
220237
projection,
221238
advanced_extension: None,
222-
read_type: Some(ReadType::NamedTable(NamedTable {
223-
names: scan.table_name.to_vec(),
224-
advanced_extension: None,
225-
})),
239+
read_type: Some(table),
226240
}))),
227241
}))
228242
}
@@ -2238,17 +2252,21 @@ mod test {
22382252
use super::*;
22392253
use crate::logical_plan::consumer::{
22402254
from_substrait_extended_expr, from_substrait_literal_without_names,
2241-
from_substrait_named_struct, from_substrait_type_without_names,
2242-
DefaultSubstraitConsumer,
2255+
from_substrait_named_struct, from_substrait_plan,
2256+
from_substrait_type_without_names, DefaultSubstraitConsumer,
22432257
};
22442258
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
22452259
use datafusion::arrow::array::{
22462260
GenericListArray, Int64Builder, MapBuilder, StringBuilder,
22472261
};
22482262
use datafusion::arrow::datatypes::{Field, Fields, Schema};
22492263
use datafusion::common::scalar::ScalarStructBuilder;
2250-
use datafusion::common::DFSchema;
2264+
use datafusion::common::{assert_contains, DFSchema};
2265+
use datafusion::datasource::empty::EmptyTable;
2266+
use datafusion::datasource::{DefaultTableSource, TableProvider};
22512267
use datafusion::execution::{SessionState, SessionStateBuilder};
2268+
use datafusion::logical_expr::registry::SerializerRegistry;
2269+
use datafusion::logical_expr::TableSource;
22522270
use datafusion::prelude::SessionContext;
22532271
use std::sync::OnceLock;
22542272

@@ -2585,4 +2603,110 @@ mod test {
25852603

25862604
assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
25872605
}
2606+
2607+
#[tokio::test]
2608+
async fn round_trip_extension_table() {
2609+
const TABLE_NAME: &str = "custom_table";
2610+
const SERIALIZED: &[u8] = "table definition".as_bytes();
2611+
2612+
fn custom_table() -> Arc<dyn TableProvider> {
2613+
Arc::new(EmptyTable::new(Arc::new(Schema::new([
2614+
Arc::new(Field::new("id", DataType::Int32, false)),
2615+
Arc::new(Field::new("name", DataType::Utf8, false)),
2616+
]))))
2617+
}
2618+
2619+
#[derive(Debug)]
2620+
struct Registry;
2621+
impl SerializerRegistry for Registry {
2622+
fn serialize_custom_table(&self, table: &dyn TableSource) -> Result<Vec<u8>> {
2623+
if table.schema() == custom_table().schema() {
2624+
Ok(SERIALIZED.to_vec())
2625+
} else {
2626+
Err(DataFusionError::Internal("Not our table".into()))
2627+
}
2628+
}
2629+
fn deserialize_custom_table(
2630+
&self,
2631+
name: &str,
2632+
bytes: &[u8],
2633+
) -> Result<Arc<dyn TableSource>> {
2634+
if name == TABLE_NAME && bytes == SERIALIZED {
2635+
Ok(Arc::new(DefaultTableSource::new(custom_table())))
2636+
} else {
2637+
panic!("Unexpected extension table: {name}");
2638+
}
2639+
}
2640+
}
2641+
2642+
async fn round_trip_logical_plans(
2643+
local: &SessionContext,
2644+
remote: &SessionContext,
2645+
) -> Result<()> {
2646+
local.register_table(TABLE_NAME, custom_table())?;
2647+
remote.table_provider(TABLE_NAME).await.expect_err(
2648+
"The remote context is not supposed to know about custom_table",
2649+
);
2650+
let initial_plan = local
2651+
.sql(&format!("select id from {TABLE_NAME}"))
2652+
.await?
2653+
.logical_plan()
2654+
.clone();
2655+
2656+
// write substrait locally
2657+
let substrait = to_substrait_plan(&initial_plan, &local.state())?;
2658+
2659+
// read substrait remotely
2660+
// since we know there's no `custom_table` registered in the remote context, this will only succeed
2661+
// if our table got encoded as an ExtensionTable and is now decoded back to a table source.
2662+
let restored = from_substrait_plan(&remote.state(), &substrait).await?;
2663+
assert_contains!(
2664+
// confirm that the Substrait plan contains our custom_table as an ExtensionTable
2665+
serde_json::to_string(substrait.as_ref()).unwrap(),
2666+
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#)
2667+
);
2668+
remote // make sure the restored plan is fully working in the remote context
2669+
.execute_logical_plan(restored.clone())
2670+
.await?
2671+
.collect()
2672+
.await
2673+
.expect("Restored plan cannot be executed remotely");
2674+
assert_eq!(
2675+
// check that the restored plan is functionally equivalent (and almost identical) to the initial one
2676+
initial_plan.to_string(),
2677+
restored.to_string().replace(
2678+
// substrait will add an explicit full-schema projection if the original table had none
2679+
&format!("TableScan: {TABLE_NAME} projection=[id, name]"),
2680+
&format!("TableScan: {TABLE_NAME}"),
2681+
)
2682+
);
2683+
Ok(())
2684+
}
2685+
2686+
// take 1
2687+
let failed_attempt =
2688+
round_trip_logical_plans(&SessionContext::new(), &SessionContext::new())
2689+
.await
2690+
.expect_err(
2691+
"The round trip should fail in the absence of a SerializerRegistry",
2692+
);
2693+
assert_contains!(
2694+
failed_attempt.message(),
2695+
format!("No table named '{TABLE_NAME}'")
2696+
);
2697+
2698+
// take 2
2699+
fn proper_context() -> SessionContext {
2700+
SessionContext::new_with_state(
2701+
SessionStateBuilder::new()
2702+
// This will transport our custom_table as a Substrait ExtensionTable
2703+
.with_serializer_registry(Arc::new(Registry))
2704+
.build(),
2705+
)
2706+
}
2707+
2708+
round_trip_logical_plans(&proper_context(), &proper_context())
2709+
.await
2710+
.expect("Local plan could not be restored remotely");
2711+
}
25882712
}

0 commit comments

Comments
 (0)