From 01e12995d4b31286a8499b61453071e3857deb31 Mon Sep 17 00:00:00 2001 From: Acfboy Date: Mon, 2 Mar 2026 10:35:11 +0800 Subject: [PATCH 1/5] feat: make DefaultLogicalExtensionCodec support serialisation of build in file formats --- datafusion/proto/proto/datafusion.proto | 14 ++ datafusion/proto/src/generated/pbjson.rs | 197 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 45 ++++ datafusion/proto/src/logical_plan/mod.rs | 102 +++++++++ .../tests/cases/roundtrip_logical_plan.rs | 183 ++++++++++++++++ 5 files changed, 541 insertions(+) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7c0268867691e..de26b477cd20b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -270,6 +270,20 @@ message CopyToNode { repeated string partition_by = 7; } +enum FileFormatKind { + FILE_FORMAT_KIND_UNSPECIFIED = 0; + FILE_FORMAT_KIND_CSV = 1; + FILE_FORMAT_KIND_JSON = 2; + FILE_FORMAT_KIND_PARQUET = 3; + FILE_FORMAT_KIND_ARROW = 4; + FILE_FORMAT_KIND_AVRO = 5; +} + +message FileFormatProto { + FileFormatKind kind = 1; + bytes options = 2; +} + message DmlNode{ enum Type { UPDATE = 0; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 5b2b9133ce13a..cec42f85c75bf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6076,6 +6076,203 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { deserializer.deserialize_struct("datafusion.ExplainNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FileFormatKind { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Unspecified => "FILE_FORMAT_KIND_UNSPECIFIED", + Self::Csv => "FILE_FORMAT_KIND_CSV", + Self::Json => "FILE_FORMAT_KIND_JSON", + Self::Parquet => "FILE_FORMAT_KIND_PARQUET", + Self::Arrow => "FILE_FORMAT_KIND_ARROW", + Self::Avro => "FILE_FORMAT_KIND_AVRO", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for FileFormatKind { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "FILE_FORMAT_KIND_UNSPECIFIED", + "FILE_FORMAT_KIND_CSV", + "FILE_FORMAT_KIND_JSON", + "FILE_FORMAT_KIND_PARQUET", + "FILE_FORMAT_KIND_ARROW", + "FILE_FORMAT_KIND_AVRO", + ]; + + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = FileFormatKind; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "FILE_FORMAT_KIND_UNSPECIFIED" => Ok(FileFormatKind::Unspecified), + "FILE_FORMAT_KIND_CSV" => Ok(FileFormatKind::Csv), + "FILE_FORMAT_KIND_JSON" => Ok(FileFormatKind::Json), + "FILE_FORMAT_KIND_PARQUET" => Ok(FileFormatKind::Parquet), + "FILE_FORMAT_KIND_ARROW" => Ok(FileFormatKind::Arrow), + "FILE_FORMAT_KIND_AVRO" => Ok(FileFormatKind::Avro), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for FileFormatProto { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.kind != 0 { + len += 1; + } + if !self.options.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileFormatProto", len)?; + if self.kind != 0 { + let v = FileFormatKind::try_from(self.kind) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.kind)))?; + struct_ser.serialize_field("kind", &v)?; + } + if !self.options.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("options", pbjson::private::base64::encode(&self.options).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FileFormatProto { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "kind", + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Kind, + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "kind" => Ok(GeneratedField::Kind), + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FileFormatProto; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FileFormatProto") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut kind__ = None; + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Kind => { + if kind__.is_some() { + return Err(serde::de::Error::duplicate_field("kind")); + } + kind__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(FileFormatProto { + kind: kind__.unwrap_or_default(), + options: options__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FileFormatProto", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for FileGroup { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d9602665c284a..3c98e46b49e5a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -412,6 +412,13 @@ pub struct CopyToNode { #[prost(string, repeated, tag = "7")] pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct FileFormatProto { + #[prost(enumeration = "FileFormatKind", tag = "1")] + pub kind: i32, + #[prost(bytes = "vec", tag = "2")] + pub options: ::prost::alloc::vec::Vec, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct DmlNode { #[prost(enumeration = "dml_node::Type", tag = "1")] @@ -2173,6 +2180,44 @@ pub struct BufferExecNode { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum FileFormatKind { + Unspecified = 0, + Csv = 1, + Json = 2, + Parquet = 3, + Arrow = 4, + Avro = 5, +} +impl FileFormatKind { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "FILE_FORMAT_KIND_UNSPECIFIED", + Self::Csv => "FILE_FORMAT_KIND_CSV", + Self::Json => "FILE_FORMAT_KIND_JSON", + Self::Parquet => "FILE_FORMAT_KIND_PARQUET", + Self::Arrow => "FILE_FORMAT_KIND_ARROW", + Self::Avro => "FILE_FORMAT_KIND_AVRO", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "FILE_FORMAT_KIND_UNSPECIFIED" => Some(Self::Unspecified), + "FILE_FORMAT_KIND_CSV" => Some(Self::Csv), + "FILE_FORMAT_KIND_JSON" => Some(Self::Json), + "FILE_FORMAT_KIND_PARQUET" => Some(Self::Parquet), + "FILE_FORMAT_KIND_ARROW" => Some(Self::Arrow), + "FILE_FORMAT_KIND_AVRO" => Some(Self::Avro), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum WindowFrameUnits { Rows = 0, Range = 1, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 218c2e4e47d04..9ac3989a62681 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -207,6 +207,108 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { ) -> Result<()> { not_impl_err!("LogicalExtensionCodec is not provided") } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &TaskContext, + ) -> Result> { + use prost::Message; + + let proto = protobuf::FileFormatProto::decode(buf) + .map_err(|e| internal_datafusion_err!("Failed to decode FileFormatProto: {e}"))?; + + let kind = protobuf::FileFormatKind::try_from(proto.kind) + .map_err(|_| internal_datafusion_err!("Unknown FileFormatKind: {}", proto.kind))?; + + match kind { + protobuf::FileFormatKind::Csv => { + file_formats::CsvLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx) + } + protobuf::FileFormatKind::Json => { + file_formats::JsonLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx) + } + #[cfg(feature = "parquet")] + protobuf::FileFormatKind::Parquet => { + file_formats::ParquetLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx) + } + protobuf::FileFormatKind::Arrow => { + file_formats::ArrowLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx) + } + protobuf::FileFormatKind::Avro => { + file_formats::AvroLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx) + } + #[cfg(not(feature = "parquet"))] + protobuf::FileFormatKind::Parquet => { + not_impl_err!("Parquet support requires the 'parquet' feature") + } + protobuf::FileFormatKind::Unspecified => { + not_impl_err!("Unspecified file format kind") + } + } + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> Result<()> { + use datafusion_datasource_arrow::file_format::ArrowFormatFactory; + use datafusion_datasource_csv::file_format::CsvFormatFactory; + use datafusion_datasource_json::file_format::JsonFormatFactory; + use prost::Message; + + let any = node.as_any(); + let mut options = Vec::new(); + + let kind = if any.downcast_ref::().is_some() { + file_formats::CsvLogicalExtensionCodec + .try_encode_file_format(&mut options, Arc::clone(&node))?; + protobuf::FileFormatKind::Csv + } else if any.downcast_ref::().is_some() { + file_formats::JsonLogicalExtensionCodec + .try_encode_file_format(&mut options, Arc::clone(&node))?; + protobuf::FileFormatKind::Json + } else if any.downcast_ref::().is_some() { + file_formats::ArrowLogicalExtensionCodec + .try_encode_file_format(&mut options, Arc::clone(&node))?; + protobuf::FileFormatKind::Arrow + } else { + #[cfg(feature = "parquet")] + { + use datafusion_datasource_parquet::file_format::ParquetFormatFactory; + if any.downcast_ref::().is_some() { + file_formats::ParquetLogicalExtensionCodec + .try_encode_file_format(&mut options, Arc::clone(&node))?; + protobuf::FileFormatKind::Parquet + } else { + return not_impl_err!( + "Unsupported FileFormatFactory type for DefaultLogicalExtensionCodec" + ); + } + } + #[cfg(not(feature = "parquet"))] + { + return not_impl_err!( + "Unsupported FileFormatFactory type for DefaultLogicalExtensionCodec" + ); + } + }; + + let proto = protobuf::FileFormatProto { + kind: kind as i32, + options, + }; + proto + .encode(buf) + .map_err(|e| internal_datafusion_err!("Failed to encode FileFormatProto: {e}"))?; + Ok(()) + } } #[macro_export] diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9407cbf9a0749..6731f4dd04f61 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -743,6 +743,189 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_default_codec_csv() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_csv_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut csv_format = table_options.csv; + csv_format.delimiter = b'|'; + csv_format.has_header = Some(true); + csv_format.compression = CompressionTypeVariant::GZIP; + + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options( + csv_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.csv".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!("csv", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let csv = dt + .as_format_factory() + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let decoded = csv.options.as_ref().unwrap(); + assert_eq!(csv_format.delimiter, decoded.delimiter); + assert_eq!(csv_format.has_header, decoded.has_header); + assert_eq!(csv_format.compression, decoded.compression); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_json() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_json_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut json_format = table_options.json; + json_format.compression = CompressionTypeVariant::GZIP; + json_format.schema_infer_max_rec = Some(500); + + let file_type = format_as_file_type(Arc::new(JsonFormatFactory::new_with_options( + json_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.json".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.json", copy_to.output_url); + assert_eq!("json", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let json = dt + .as_format_factory() + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let decoded = json.options.as_ref().unwrap(); + assert_eq!(json_format.compression, decoded.compression); + assert_eq!(json_format.schema_infer_max_rec, decoded.schema_infer_max_rec); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_parquet() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_parquet_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut parquet_format = table_options.parquet; + parquet_format.global.bloom_filter_on_read = true; + parquet_format.global.created_by = "DefaultCodecTest".to_string(); + + let file_type = format_as_file_type(Arc::new( + ParquetFormatFactory::new_with_options(parquet_format.clone()), + )); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.parquet".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!("parquet", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let pq = dt + .as_format_factory() + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let decoded = pq.options.as_ref().unwrap(); + assert!(decoded.global.bloom_filter_on_read); + assert_eq!("DefaultCodecTest", decoded.global.created_by); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_arrow() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_csv_scan(&ctx).await?; + + let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.arrow".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.arrow", copy_to.output_url); + assert_eq!("arrow", copy_to.file_type.get_ext()); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; From e2275e2fd97e0b9ab288de0328b726e997d43980 Mon Sep 17 00:00:00 2001 From: Acfboy Date: Mon, 2 Mar 2026 10:57:35 +0800 Subject: [PATCH 2/5] fmt --- datafusion/proto/src/logical_plan/mod.rs | 40 ++++++++----------- .../tests/cases/roundtrip_logical_plan.rs | 5 ++- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9ac3989a62681..590a81769b5e4 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -215,34 +215,28 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { ) -> Result> { use prost::Message; - let proto = protobuf::FileFormatProto::decode(buf) - .map_err(|e| internal_datafusion_err!("Failed to decode FileFormatProto: {e}"))?; + let proto = protobuf::FileFormatProto::decode(buf).map_err(|e| { + internal_datafusion_err!("Failed to decode FileFormatProto: {e}") + })?; - let kind = protobuf::FileFormatKind::try_from(proto.kind) - .map_err(|_| internal_datafusion_err!("Unknown FileFormatKind: {}", proto.kind))?; + let kind = protobuf::FileFormatKind::try_from(proto.kind).map_err(|_| { + internal_datafusion_err!("Unknown FileFormatKind: {}", proto.kind) + })?; match kind { - protobuf::FileFormatKind::Csv => { - file_formats::CsvLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx) - } - protobuf::FileFormatKind::Json => { - file_formats::JsonLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx) - } + protobuf::FileFormatKind::Csv => file_formats::CsvLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx), + protobuf::FileFormatKind::Json => file_formats::JsonLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx), #[cfg(feature = "parquet")] protobuf::FileFormatKind::Parquet => { file_formats::ParquetLogicalExtensionCodec .try_decode_file_format(&proto.options, ctx) } - protobuf::FileFormatKind::Arrow => { - file_formats::ArrowLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx) - } - protobuf::FileFormatKind::Avro => { - file_formats::AvroLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx) - } + protobuf::FileFormatKind::Arrow => file_formats::ArrowLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx), + protobuf::FileFormatKind::Avro => file_formats::AvroLogicalExtensionCodec + .try_decode_file_format(&proto.options, ctx), #[cfg(not(feature = "parquet"))] protobuf::FileFormatKind::Parquet => { not_impl_err!("Parquet support requires the 'parquet' feature") @@ -304,9 +298,9 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { kind: kind as i32, options, }; - proto - .encode(buf) - .map_err(|e| internal_datafusion_err!("Failed to encode FileFormatProto: {e}"))?; + proto.encode(buf).map_err(|e| { + internal_datafusion_err!("Failed to encode FileFormatProto: {e}") + })?; Ok(()) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6731f4dd04f61..63ad00c92e6a9 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -840,7 +840,10 @@ async fn roundtrip_default_codec_json() -> Result<()> { .unwrap(); let decoded = json.options.as_ref().unwrap(); assert_eq!(json_format.compression, decoded.compression); - assert_eq!(json_format.schema_infer_max_rec, decoded.schema_infer_max_rec); + assert_eq!( + json_format.schema_infer_max_rec, + decoded.schema_infer_max_rec + ); } _ => panic!("Expected CopyTo plan"), } From 7364d53db23fd97aaf1166c1022fa960ae41b162 Mon Sep 17 00:00:00 2001 From: Acfboy Date: Fri, 6 Mar 2026 14:45:35 +0800 Subject: [PATCH 3/5] address review --- datafusion/proto/proto/datafusion.proto | 7 +++- datafusion/proto/src/generated/pbjson.rs | 25 +++++++------- datafusion/proto/src/generated/prost.rs | 7 +++- datafusion/proto/src/logical_plan/mod.rs | 44 +++++++++++------------- 4 files changed, 45 insertions(+), 38 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index de26b477cd20b..25ecdccfcde69 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -270,6 +270,9 @@ message CopyToNode { repeated string partition_by = 7; } +// Identifies a built-in file format supported by DataFusion. +// Used by DefaultLogicalExtensionCodec to serialize/deserialize +// FileFormatFactory instances (e.g. in CopyTo plans). enum FileFormatKind { FILE_FORMAT_KIND_UNSPECIFIED = 0; FILE_FORMAT_KIND_CSV = 1; @@ -279,9 +282,11 @@ enum FileFormatKind { FILE_FORMAT_KIND_AVRO = 5; } +// Wraps a serialized FileFormatFactory with its format kind tag, +// so the decoder can dispatch to the correct format-specific codec. message FileFormatProto { FileFormatKind kind = 1; - bytes options = 2; + bytes encoded_file_format = 2; } message DmlNode{ diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cec42f85c75bf..148493b3c1c0f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -6170,7 +6170,7 @@ impl serde::Serialize for FileFormatProto { if self.kind != 0 { len += 1; } - if !self.options.is_empty() { + if !self.encoded_file_format.is_empty() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.FileFormatProto", len)?; @@ -6179,10 +6179,10 @@ impl serde::Serialize for FileFormatProto { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.kind)))?; struct_ser.serialize_field("kind", &v)?; } - if !self.options.is_empty() { + if !self.encoded_file_format.is_empty() { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("options", pbjson::private::base64::encode(&self.options).as_str())?; + struct_ser.serialize_field("encodedFileFormat", pbjson::private::base64::encode(&self.encoded_file_format).as_str())?; } struct_ser.end() } @@ -6195,13 +6195,14 @@ impl<'de> serde::Deserialize<'de> for FileFormatProto { { const FIELDS: &[&str] = &[ "kind", - "options", + "encoded_file_format", + "encodedFileFormat", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Kind, - Options, + EncodedFileFormat, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6224,7 +6225,7 @@ impl<'de> serde::Deserialize<'de> for FileFormatProto { { match value { "kind" => Ok(GeneratedField::Kind), - "options" => Ok(GeneratedField::Options), + "encodedFileFormat" | "encoded_file_format" => Ok(GeneratedField::EncodedFileFormat), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6245,7 +6246,7 @@ impl<'de> serde::Deserialize<'de> for FileFormatProto { V: serde::de::MapAccess<'de>, { let mut kind__ = None; - let mut options__ = None; + let mut encoded_file_format__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Kind => { @@ -6254,11 +6255,11 @@ impl<'de> serde::Deserialize<'de> for FileFormatProto { } kind__ = Some(map_.next_value::()? as i32); } - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); + GeneratedField::EncodedFileFormat => { + if encoded_file_format__.is_some() { + return Err(serde::de::Error::duplicate_field("encodedFileFormat")); } - options__ = + encoded_file_format__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -6266,7 +6267,7 @@ impl<'de> serde::Deserialize<'de> for FileFormatProto { } Ok(FileFormatProto { kind: kind__.unwrap_or_default(), - options: options__.unwrap_or_default(), + encoded_file_format: encoded_file_format__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3c98e46b49e5a..4fce6ead6f39e 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -412,12 +412,14 @@ pub struct CopyToNode { #[prost(string, repeated, tag = "7")] pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } +/// Wraps a serialized FileFormatFactory with its format kind tag, +/// so the decoder can dispatch to the correct format-specific codec. #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FileFormatProto { #[prost(enumeration = "FileFormatKind", tag = "1")] pub kind: i32, #[prost(bytes = "vec", tag = "2")] - pub options: ::prost::alloc::vec::Vec, + pub encoded_file_format: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DmlNode { @@ -2178,6 +2180,9 @@ pub struct BufferExecNode { #[prost(uint64, tag = "2")] pub capacity: u64, } +/// Identifies a built-in file format supported by DataFusion. +/// Used by DefaultLogicalExtensionCodec to serialize/deserialize +/// FileFormatFactory instances (e.g. in CopyTo plans). #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum FileFormatKind { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 590a81769b5e4..34a27e3e453b7 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -44,13 +44,15 @@ use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_format::{ FileFormatFactory, file_type_to_format, format_as_file_type, }; -use datafusion_datasource_arrow::file_format::ArrowFormat; +use datafusion_datasource_arrow::file_format::{ArrowFormat, ArrowFormatFactory}; #[cfg(feature = "avro")] use datafusion_datasource_avro::file_format::AvroFormat; -use datafusion_datasource_csv::file_format::CsvFormat; -use datafusion_datasource_json::file_format::JsonFormat as OtherNdJsonFormat; +use datafusion_datasource_csv::file_format::{CsvFormat, CsvFormatFactory}; +use datafusion_datasource_json::file_format::{ + JsonFormat as OtherNdJsonFormat, JsonFormatFactory, +}; #[cfg(feature = "parquet")] -use datafusion_datasource_parquet::file_format::ParquetFormat; +use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, }; @@ -213,8 +215,6 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { buf: &[u8], ctx: &TaskContext, ) -> Result> { - use prost::Message; - let proto = protobuf::FileFormatProto::decode(buf).map_err(|e| { internal_datafusion_err!("Failed to decode FileFormatProto: {e}") })?; @@ -225,18 +225,18 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { match kind { protobuf::FileFormatKind::Csv => file_formats::CsvLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx), + .try_decode_file_format(&proto.encoded_file_format, ctx), protobuf::FileFormatKind::Json => file_formats::JsonLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx), + .try_decode_file_format(&proto.encoded_file_format, ctx), #[cfg(feature = "parquet")] protobuf::FileFormatKind::Parquet => { file_formats::ParquetLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx) + .try_decode_file_format(&proto.encoded_file_format, ctx) } protobuf::FileFormatKind::Arrow => file_formats::ArrowLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx), + .try_decode_file_format(&proto.encoded_file_format, ctx), protobuf::FileFormatKind::Avro => file_formats::AvroLogicalExtensionCodec - .try_decode_file_format(&proto.options, ctx), + .try_decode_file_format(&proto.encoded_file_format, ctx), #[cfg(not(feature = "parquet"))] protobuf::FileFormatKind::Parquet => { not_impl_err!("Parquet support requires the 'parquet' feature") @@ -252,33 +252,29 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { buf: &mut Vec, node: Arc, ) -> Result<()> { - use datafusion_datasource_arrow::file_format::ArrowFormatFactory; - use datafusion_datasource_csv::file_format::CsvFormatFactory; - use datafusion_datasource_json::file_format::JsonFormatFactory; - use prost::Message; + let mut encoded_file_format = Vec::new(); let any = node.as_any(); - let mut options = Vec::new(); - let kind = if any.downcast_ref::().is_some() { file_formats::CsvLogicalExtensionCodec - .try_encode_file_format(&mut options, Arc::clone(&node))?; + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Csv } else if any.downcast_ref::().is_some() { file_formats::JsonLogicalExtensionCodec - .try_encode_file_format(&mut options, Arc::clone(&node))?; + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Json } else if any.downcast_ref::().is_some() { file_formats::ArrowLogicalExtensionCodec - .try_encode_file_format(&mut options, Arc::clone(&node))?; + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Arrow } else { #[cfg(feature = "parquet")] { - use datafusion_datasource_parquet::file_format::ParquetFormatFactory; if any.downcast_ref::().is_some() { - file_formats::ParquetLogicalExtensionCodec - .try_encode_file_format(&mut options, Arc::clone(&node))?; + file_formats::ParquetLogicalExtensionCodec.try_encode_file_format( + &mut encoded_file_format, + Arc::clone(&node), + )?; protobuf::FileFormatKind::Parquet } else { return not_impl_err!( @@ -296,7 +292,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { let proto = protobuf::FileFormatProto { kind: kind as i32, - options, + encoded_file_format, }; proto.encode(buf).map_err(|e| { internal_datafusion_err!("Failed to encode FileFormatProto: {e}") From 70ca0680dfca4fd47daedd5935c120e32c06e79d Mon Sep 17 00:00:00 2001 From: Acfboy Date: Fri, 6 Mar 2026 22:21:06 +0800 Subject: [PATCH 4/5] address review: extract variable --- datafusion/proto/src/logical_plan/mod.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 34a27e3e453b7..d2ebf40bdb847 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -254,23 +254,22 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { ) -> Result<()> { let mut encoded_file_format = Vec::new(); - let any = node.as_any(); - let kind = if any.downcast_ref::().is_some() { + let kind = if node.as_any().downcast_ref::().is_some() { file_formats::CsvLogicalExtensionCodec .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Csv - } else if any.downcast_ref::().is_some() { + } else if node.as_any().downcast_ref::().is_some() { file_formats::JsonLogicalExtensionCodec .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Json - } else if any.downcast_ref::().is_some() { + } else if node.as_any().downcast_ref::().is_some() { file_formats::ArrowLogicalExtensionCodec .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; protobuf::FileFormatKind::Arrow } else { #[cfg(feature = "parquet")] { - if any.downcast_ref::().is_some() { + if node.as_any().downcast_ref::().is_some() { file_formats::ParquetLogicalExtensionCodec.try_encode_file_format( &mut encoded_file_format, Arc::clone(&node), From 9cdb1fab924d0d21a7a1fa567402d939fb9f6887 Mon Sep 17 00:00:00 2001 From: Acfboy Date: Fri, 6 Mar 2026 23:46:21 +0800 Subject: [PATCH 5/5] fmt --- datafusion/proto/src/logical_plan/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d2ebf40bdb847..a5d74d7f49fae 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -269,7 +269,11 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { } else { #[cfg(feature = "parquet")] { - if node.as_any().downcast_ref::().is_some() { + if node + .as_any() + .downcast_ref::() + .is_some() + { file_formats::ParquetLogicalExtensionCodec.try_encode_file_format( &mut encoded_file_format, Arc::clone(&node),