@@ -63,7 +63,7 @@ use substrait::proto::expression::literal::{
6363} ;
6464use substrait:: proto:: expression:: subquery:: InPredicate ;
6565use substrait:: proto:: expression:: window_function:: BoundsType ;
66- use substrait:: proto:: read_rel:: VirtualTable ;
66+ use substrait:: proto:: read_rel:: { ExtensionTable , VirtualTable } ;
6767use substrait:: proto:: rel_common:: EmitKind ;
6868use substrait:: proto:: rel_common:: EmitKind :: Emit ;
6969use substrait:: proto:: {
@@ -211,6 +211,23 @@ pub fn to_substrait_rel(
211211 let table_schema = scan. source . schema ( ) . to_dfschema_ref ( ) ?;
212212 let base_schema = to_substrait_named_struct ( & table_schema) ?;
213213
214+ let table = if let Ok ( bytes) = state
215+ . serializer_registry ( )
216+ . serialize_custom_table ( scan. source . as_ref ( ) )
217+ {
218+ ReadType :: ExtensionTable ( ExtensionTable {
219+ detail : Some ( ProtoAny {
220+ type_url : scan. table_name . to_string ( ) ,
221+ value : bytes. into ( ) ,
222+ } ) ,
223+ } )
224+ } else {
225+ ReadType :: NamedTable ( NamedTable {
226+ names : scan. table_name . to_vec ( ) ,
227+ advanced_extension : None ,
228+ } )
229+ } ;
230+
214231 Ok ( Box :: new ( Rel {
215232 rel_type : Some ( RelType :: Read ( Box :: new ( ReadRel {
216233 common : None ,
@@ -219,10 +236,7 @@ pub fn to_substrait_rel(
219236 best_effort_filter : None ,
220237 projection,
221238 advanced_extension : None ,
222- read_type : Some ( ReadType :: NamedTable ( NamedTable {
223- names : scan. table_name . to_vec ( ) ,
224- advanced_extension : None ,
225- } ) ) ,
239+ read_type : Some ( table) ,
226240 } ) ) ) ,
227241 } ) )
228242 }
@@ -2238,17 +2252,21 @@ mod test {
22382252 use super :: * ;
22392253 use crate :: logical_plan:: consumer:: {
22402254 from_substrait_extended_expr, from_substrait_literal_without_names,
2241- from_substrait_named_struct, from_substrait_type_without_names ,
2242- DefaultSubstraitConsumer ,
2255+ from_substrait_named_struct, from_substrait_plan ,
2256+ from_substrait_type_without_names , DefaultSubstraitConsumer ,
22432257 } ;
22442258 use arrow_buffer:: { IntervalDayTime , IntervalMonthDayNano } ;
22452259 use datafusion:: arrow:: array:: {
22462260 GenericListArray , Int64Builder , MapBuilder , StringBuilder ,
22472261 } ;
22482262 use datafusion:: arrow:: datatypes:: { Field , Fields , Schema } ;
22492263 use datafusion:: common:: scalar:: ScalarStructBuilder ;
2250- use datafusion:: common:: DFSchema ;
2264+ use datafusion:: common:: { assert_contains, DFSchema } ;
2265+ use datafusion:: datasource:: empty:: EmptyTable ;
2266+ use datafusion:: datasource:: { DefaultTableSource , TableProvider } ;
22512267 use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
2268+ use datafusion:: logical_expr:: registry:: SerializerRegistry ;
2269+ use datafusion:: logical_expr:: TableSource ;
22522270 use datafusion:: prelude:: SessionContext ;
22532271 use std:: sync:: OnceLock ;
22542272
@@ -2585,4 +2603,110 @@ mod test {
25852603
25862604 assert ! ( matches!( err, Err ( DataFusionError :: SchemaError ( _, _) ) ) ) ;
25872605 }
2606+
2607+ #[ tokio:: test]
2608+ async fn round_trip_extension_table ( ) {
2609+ const TABLE_NAME : & str = "custom_table" ;
2610+ const SERIALIZED : & [ u8 ] = "table definition" . as_bytes ( ) ;
2611+
2612+ fn custom_table ( ) -> Arc < dyn TableProvider > {
2613+ Arc :: new ( EmptyTable :: new ( Arc :: new ( Schema :: new ( [
2614+ Arc :: new ( Field :: new ( "id" , DataType :: Int32 , false ) ) ,
2615+ Arc :: new ( Field :: new ( "name" , DataType :: Utf8 , false ) ) ,
2616+ ] ) ) ) )
2617+ }
2618+
2619+ #[ derive( Debug ) ]
2620+ struct Registry ;
2621+ impl SerializerRegistry for Registry {
2622+ fn serialize_custom_table ( & self , table : & dyn TableSource ) -> Result < Vec < u8 > > {
2623+ if table. schema ( ) == custom_table ( ) . schema ( ) {
2624+ Ok ( SERIALIZED . to_vec ( ) )
2625+ } else {
2626+ Err ( DataFusionError :: Internal ( "Not our table" . into ( ) ) )
2627+ }
2628+ }
2629+ fn deserialize_custom_table (
2630+ & self ,
2631+ name : & str ,
2632+ bytes : & [ u8 ] ,
2633+ ) -> Result < Arc < dyn TableSource > > {
2634+ if name == TABLE_NAME && bytes == SERIALIZED {
2635+ Ok ( Arc :: new ( DefaultTableSource :: new ( custom_table ( ) ) ) )
2636+ } else {
2637+ panic ! ( "Unexpected extension table: {name}" ) ;
2638+ }
2639+ }
2640+ }
2641+
2642+ async fn round_trip_logical_plans (
2643+ local : & SessionContext ,
2644+ remote : & SessionContext ,
2645+ ) -> Result < ( ) > {
2646+ local. register_table ( TABLE_NAME , custom_table ( ) ) ?;
2647+ remote. table_provider ( TABLE_NAME ) . await . expect_err (
2648+ "The remote context is not supposed to know about custom_table" ,
2649+ ) ;
2650+ let initial_plan = local
2651+ . sql ( & format ! ( "select id from {TABLE_NAME}" ) )
2652+ . await ?
2653+ . logical_plan ( )
2654+ . clone ( ) ;
2655+
2656+ // write substrait locally
2657+ let substrait = to_substrait_plan ( & initial_plan, & local. state ( ) ) ?;
2658+
2659+ // read substrait remotely
2660+ // since we know there's no `custom_table` registered in the remote context, this will only succeed
2661+ // if our table got encoded as an ExtensionTable and is now decoded back to a table source.
2662+ let restored = from_substrait_plan ( & remote. state ( ) , & substrait) . await ?;
2663+ assert_contains ! (
2664+ // confirm that the Substrait plan contains our custom_table as an ExtensionTable
2665+ serde_json:: to_string( substrait. as_ref( ) ) . unwrap( ) ,
2666+ format!( r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","# )
2667+ ) ;
2668+ remote // make sure the restored plan is fully working in the remote context
2669+ . execute_logical_plan ( restored. clone ( ) )
2670+ . await ?
2671+ . collect ( )
2672+ . await
2673+ . expect ( "Restored plan cannot be executed remotely" ) ;
2674+ assert_eq ! (
2675+ // check that the restored plan is functionally equivalent (and almost identical) to the initial one
2676+ initial_plan. to_string( ) ,
2677+ restored. to_string( ) . replace(
2678+ // substrait will add an explicit full-schema projection if the original table had none
2679+ & format!( "TableScan: {TABLE_NAME} projection=[id, name]" ) ,
2680+ & format!( "TableScan: {TABLE_NAME}" ) ,
2681+ )
2682+ ) ;
2683+ Ok ( ( ) )
2684+ }
2685+
2686+ // take 1
2687+ let failed_attempt =
2688+ round_trip_logical_plans ( & SessionContext :: new ( ) , & SessionContext :: new ( ) )
2689+ . await
2690+ . expect_err (
2691+ "The round trip should fail in the absence of a SerializerRegistry" ,
2692+ ) ;
2693+ assert_contains ! (
2694+ failed_attempt. message( ) ,
2695+ format!( "No table named '{TABLE_NAME}'" )
2696+ ) ;
2697+
2698+ // take 2
2699+ fn proper_context ( ) -> SessionContext {
2700+ SessionContext :: new_with_state (
2701+ SessionStateBuilder :: new ( )
2702+ // This will transport our custom_table as a Substrait ExtensionTable
2703+ . with_serializer_registry ( Arc :: new ( Registry ) )
2704+ . build ( ) ,
2705+ )
2706+ }
2707+
2708+ round_trip_logical_plans ( & proper_context ( ) , & proper_context ( ) )
2709+ . await
2710+ . expect ( "Local plan could not be restored remotely" ) ;
2711+ }
25882712}
0 commit comments