diff --git a/datafusion/proto-models/proto/datafusion.proto b/datafusion/proto-models/proto/datafusion.proto index 2f5b75e40937e..72bc3e6bc0941 100644 --- a/datafusion/proto-models/proto/datafusion.proto +++ b/datafusion/proto-models/proto/datafusion.proto @@ -148,9 +148,19 @@ message RepartitionNode { oneof partition_method { uint64 round_robin = 2; HashRepartition hash = 3; + RangeRepartition range = 4; } } +message RangeSplitPoint { + repeated datafusion_common.ScalarValue value = 1; +} + +message RangeRepartition { + repeated SortExprNode expr = 1; + repeated RangeSplitPoint split_points = 2; +} + message HashRepartition { repeated LogicalExprNode hash_expr = 1; uint64 partition_count = 2; diff --git a/datafusion/proto-models/src/generated/pbjson.rs b/datafusion/proto-models/src/generated/pbjson.rs index 733da68fe89c2..21d33270e7a87 100644 --- a/datafusion/proto-models/src/generated/pbjson.rs +++ b/datafusion/proto-models/src/generated/pbjson.rs @@ -22350,6 +22350,206 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for RangeRepartition { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.expr.is_empty() { + len += 1; + } + if !self.split_points.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RangeRepartition", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + if !self.split_points.is_empty() { + struct_ser.serialize_field("splitPoints", &self.split_points)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RangeRepartition { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + "split_points", + "splitPoints", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + SplitPoints, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + "splitPoints" | "split_points" => Ok(GeneratedField::SplitPoints), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RangeRepartition; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RangeRepartition") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + let mut split_points__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + GeneratedField::SplitPoints => { + if split_points__.is_some() { + return Err(serde::de::Error::duplicate_field("splitPoints")); + } + split_points__ = Some(map_.next_value()?); + } + } + } + Ok(RangeRepartition { + expr: expr__.unwrap_or_default(), + split_points: split_points__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RangeRepartition", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for RangeSplitPoint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RangeSplitPoint", len)?; + if !self.value.is_empty() { + struct_ser.serialize_field("value", &self.value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RangeSplitPoint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RangeSplitPoint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RangeSplitPoint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = Some(map_.next_value()?); + } + } + } + Ok(RangeSplitPoint { + value: value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.RangeSplitPoint", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for RecursionUnnestOption { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -22778,6 +22978,9 @@ impl serde::Serialize for RepartitionNode { repartition_node::PartitionMethod::Hash(v) => { struct_ser.serialize_field("hash", v)?; } + repartition_node::PartitionMethod::Range(v) => { + struct_ser.serialize_field("range", v)?; + } } } struct_ser.end() @@ -22794,6 +22997,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { "round_robin", "roundRobin", "hash", + "range", ]; #[allow(clippy::enum_variant_names)] @@ -22801,6 +23005,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { Input, RoundRobin, Hash, + Range, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22825,6 +23030,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { "input" => Ok(GeneratedField::Input), "roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin), "hash" => Ok(GeneratedField::Hash), + "range" => Ok(GeneratedField::Range), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22865,6 +23071,13 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { return Err(serde::de::Error::duplicate_field("hash")); } partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) +; + } + GeneratedField::Range => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("range")); + } + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Range) ; } } diff --git a/datafusion/proto-models/src/generated/prost.rs b/datafusion/proto-models/src/generated/prost.rs index 4a2edeeb11eca..a821be9e289cc 100644 --- a/datafusion/proto-models/src/generated/prost.rs +++ b/datafusion/proto-models/src/generated/prost.rs @@ -210,7 +210,7 @@ pub struct SortNode { pub struct RepartitionNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "repartition_node::PartitionMethod", tags = "2, 3")] + #[prost(oneof = "repartition_node::PartitionMethod", tags = "2, 3, 4")] pub partition_method: ::core::option::Option, } /// Nested message and enum types in `RepartitionNode`. @@ -221,9 +221,23 @@ pub mod repartition_node { RoundRobin(u64), #[prost(message, tag = "3")] Hash(super::HashRepartition), + #[prost(message, tag = "4")] + Range(super::RangeRepartition), } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct RangeSplitPoint { + #[prost(message, repeated, tag = "1")] + pub value: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RangeRepartition { + #[prost(message, repeated, tag = "1")] + pub expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub split_points: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct HashRepartition { #[prost(message, repeated, tag = "1")] pub hash_expr: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b79b21b3599c7..cda1de3a3272a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field}; use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ - NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, + NullEquality, RecursionUnnestOption, Result, ScalarValue, SplitPoint, TableReference, UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, }; use datafusion_execution::TaskContext; @@ -899,3 +899,14 @@ fn parse_required_expr( fn proto_error>(message: S) -> Error { Error::General(message.into()) } + +pub fn parse_protobuf_range_split_point( + split_point: &protobuf::RangeSplitPoint, +) -> Result { + let values = split_point + .value + .iter() + .map(ScalarValue::try_from) + .collect::>()?; + Ok(SplitPoint::new(values)) +} diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a0604cb6b03e6..b4c6a7be7a15a 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -60,8 +60,8 @@ use datafusion_datasource_json::file_format::{ use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::dml::InsertOp; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType, - TableSource, Unnest, WriteOp, + AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RangePartitioning, + RecursiveQuery, SkipType, TableSource, Unnest, WriteOp, }; use datafusion_expr::{ DistinctOn, DropView, Expr, JoinConstraint, LogicalPlan, LogicalPlanBuilder, @@ -76,6 +76,7 @@ use datafusion_expr::{ use datafusion_proto_common::protobuf_common; use self::to_proto::{serialize_expr, serialize_exprs}; +use crate::logical_plan::to_proto::serialize_range_split_point; use crate::logical_plan::to_proto::serialize_sorts; use datafusion_catalog::TableProvider; use datafusion_catalog::default_table_source::{provider_as_source, source_as_provider}; @@ -745,6 +746,16 @@ impl AsLogicalPlan for LogicalPlanNode { PartitionMethod::RoundRobin(partition_count) => { Partitioning::RoundRobinBatch(*partition_count as usize) } + PartitionMethod::Range(protobuf::RangeRepartition { + expr: pb_sort_expr, + split_points, + }) => Partitioning::Range(RangePartitioning::try_new( + from_proto::parse_sorts(pb_sort_expr, ctx, extension_codec)?, + split_points + .iter() + .map(from_proto::parse_protobuf_range_split_point) + .collect::, _>>()?, + )?), }; LogicalPlanBuilder::from(input) @@ -1754,10 +1765,18 @@ impl AsLogicalPlan for LogicalPlanNode { Partitioning::RoundRobinBatch(partition_count) => { PartitionMethod::RoundRobin(*partition_count as u64) } - Partitioning::Range(_) => { - // TODO: Support range repartition protobuf serialization. - // Tracked by https://github.com/apache/datafusion/issues/22787 - return not_impl_err!("Range repartition"); + Partitioning::Range(range_partitioning) => { + let ordering = range_partitioning.ordering(); + let split_points = range_partitioning + .split_points() + .iter() + .map(serialize_range_split_point) + .collect::, _>>()?; + + PartitionMethod::Range(protobuf::RangeRepartition { + expr: serialize_sorts(ordering, extension_codec)?, + split_points, + }) } Partitioning::DistributeBy(_) => { return not_impl_err!("DistributeBy"); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 516aca4094451..dda49bbd4faaf 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; -use datafusion_common::{NullEquality, TableReference, UnnestOptions}; +use datafusion_common::{NullEquality, SplitPoint, TableReference, UnnestOptions}; use datafusion_expr::dml::{ MergeIntoAction, MergeIntoClause, MergeIntoClauseKind, MergeIntoOp, }; @@ -687,6 +687,18 @@ where .collect::, Error>>() } +pub fn serialize_range_split_point( + split_point: &SplitPoint, +) -> Result { + Ok(protobuf::RangeSplitPoint { + value: split_point + .values() + .iter() + .map(TryInto::::try_into) + .collect::>()?, + }) +} + impl FromProto for protobuf::TableReference { fn from_proto(t: TableReference) -> Self { use protobuf::table_reference::TableReferenceEnum; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9d8e5c2b1ef48..0d2b58e5af67f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -73,8 +73,8 @@ use datafusion_common::format::{ }; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, - internal_datafusion_err, internal_err, not_impl_err, plan_err, + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SplitPoint, + TableReference, internal_datafusion_err, internal_err, not_impl_err, plan_err, }; use datafusion_execution::TaskContext; use datafusion_expr::dml::CopyTo; @@ -91,9 +91,9 @@ use datafusion_expr::logical_plan::{ use datafusion_expr::{ Accumulator, AggregateUDF, ColumnarValue, DmlStatement, ExprFunctionExt, ExprSchemable, LimitEffect, Literal, LogicalPlan, LogicalPlanBuilder, Operator, - PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, WriteOp, + PartitionEvaluator, RangePartitioning, Repartition, ScalarUDF, Signature, TryCast, + Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, WriteOp, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ @@ -3509,3 +3509,33 @@ async fn roundtrip_join_null_equality() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn roundtrip_range_partitioning() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let scan_plan = ctx.sql("SELECT * FROM t1").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Range(RangePartitioning::try_new( + vec![col("a").sort(true, true)], + vec![SplitPoint::new(vec![ScalarValue::Int64(Some(2))])], + )?), + }); + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +}