Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions datafusion/proto-models/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,19 @@ message RepartitionNode {
oneof partition_method {
uint64 round_robin = 2;
HashRepartition hash = 3;
RangeRepartition range = 4;
}
}

// A single split point: a list of scalar values.
// `RangeRepartition.split_points` is a repeated field of this message.
// NOTE(review): presumably one value per sort expression in
// `RangeRepartition.expr` — confirm against the producer.
message RangeSplitPoint {
// Scalar values that make up this split point.
repeated datafusion_common.ScalarValue value = 1;
}

message RangeRepartition {
repeated SortExprNode expr = 1;

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the corresponding physical plan field is named `sort_expr`. Should I rename this field to match?

repeated RangeSplitPoint split_points = 2;
}

message HashRepartition {
repeated LogicalExprNode hash_expr = 1;
uint64 partition_count = 2;
Expand Down
213 changes: 213 additions & 0 deletions datafusion/proto-models/src/generated/pbjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22350,6 +22350,206 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode {
deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for RangeRepartition {
    /// Serializes this message as the struct `datafusion.RangeRepartition`,
    /// writing only the repeated fields that are non-empty.
    #[allow(deprecated)]
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let has_expr = !self.expr.is_empty();
        let has_split_points = !self.split_points.is_empty();
        // Number of fields that will actually be written below.
        let field_count = usize::from(has_expr) + usize::from(has_split_points);

        let mut state =
            serializer.serialize_struct("datafusion.RangeRepartition", field_count)?;
        if has_expr {
            state.serialize_field("expr", &self.expr)?;
        }
        if has_split_points {
            state.serialize_field("splitPoints", &self.split_points)?;
        }
        state.end()
    }
}
impl<'de> serde::Deserialize<'de> for RangeRepartition {
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
"expr",
"split_points",
"splitPoints",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Expr,
SplitPoints,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GeneratedVisitor;

impl serde::de::Visitor<'_> for GeneratedVisitor {
type Value = GeneratedField;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

#[allow(unused_variables)]
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
where
E: serde::de::Error,
{
match value {
"expr" => Ok(GeneratedField::Expr),
"splitPoints" | "split_points" => Ok(GeneratedField::SplitPoints),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(GeneratedVisitor)
}
}
struct GeneratedVisitor;
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = RangeRepartition;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct datafusion.RangeRepartition")
}

fn visit_map<V>(self, mut map_: V) -> std::result::Result<RangeRepartition, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut expr__ = None;
let mut split_points__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Expr => {
if expr__.is_some() {
return Err(serde::de::Error::duplicate_field("expr"));
}
expr__ = Some(map_.next_value()?);
}
GeneratedField::SplitPoints => {
if split_points__.is_some() {
return Err(serde::de::Error::duplicate_field("splitPoints"));
}
split_points__ = Some(map_.next_value()?);
}
}
}
Ok(RangeRepartition {
expr: expr__.unwrap_or_default(),
split_points: split_points__.unwrap_or_default(),
})
}
}
deserializer.deserialize_struct("datafusion.RangeRepartition", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for RangeSplitPoint {
    /// Serializes this message as the struct `datafusion.RangeSplitPoint`,
    /// writing the `value` field only when it is non-empty.
    #[allow(deprecated)]
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let has_value = !self.value.is_empty();
        // Number of fields that will actually be written below.
        let field_count = usize::from(has_value);

        let mut state =
            serializer.serialize_struct("datafusion.RangeSplitPoint", field_count)?;
        if has_value {
            state.serialize_field("value", &self.value)?;
        }
        state.end()
    }
}
impl<'de> serde::Deserialize<'de> for RangeSplitPoint {
#[allow(deprecated)]
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
"value",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Value,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
where
D: serde::Deserializer<'de>,
{
struct GeneratedVisitor;

impl serde::de::Visitor<'_> for GeneratedVisitor {
type Value = GeneratedField;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "expected one of: {:?}", &FIELDS)
}

#[allow(unused_variables)]
fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
where
E: serde::de::Error,
{
match value {
"value" => Ok(GeneratedField::Value),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
}
deserializer.deserialize_identifier(GeneratedVisitor)
}
}
struct GeneratedVisitor;
impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
type Value = RangeSplitPoint;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct datafusion.RangeSplitPoint")
}

fn visit_map<V>(self, mut map_: V) -> std::result::Result<RangeSplitPoint, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut value__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Value => {
if value__.is_some() {
return Err(serde::de::Error::duplicate_field("value"));
}
value__ = Some(map_.next_value()?);
}
}
}
Ok(RangeSplitPoint {
value: value__.unwrap_or_default(),
})
}
}
deserializer.deserialize_struct("datafusion.RangeSplitPoint", FIELDS, GeneratedVisitor)
}
}
impl serde::Serialize for RecursionUnnestOption {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
Expand Down Expand Up @@ -22778,6 +22978,9 @@ impl serde::Serialize for RepartitionNode {
repartition_node::PartitionMethod::Hash(v) => {
struct_ser.serialize_field("hash", v)?;
}
repartition_node::PartitionMethod::Range(v) => {
struct_ser.serialize_field("range", v)?;
}
}
}
struct_ser.end()
Expand All @@ -22794,13 +22997,15 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode {
"round_robin",
"roundRobin",
"hash",
"range",
];

#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Input,
RoundRobin,
Hash,
Range,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
Expand All @@ -22825,6 +23030,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode {
"input" => Ok(GeneratedField::Input),
"roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin),
"hash" => Ok(GeneratedField::Hash),
"range" => Ok(GeneratedField::Range),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
}
Expand Down Expand Up @@ -22865,6 +23071,13 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode {
return Err(serde::de::Error::duplicate_field("hash"));
}
partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash)
;
}
GeneratedField::Range => {
if partition_method__.is_some() {
return Err(serde::de::Error::duplicate_field("range"));
}
partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Range)
;
}
}
Expand Down
16 changes: 15 additions & 1 deletion datafusion/proto-models/src/generated/prost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ pub struct SortNode {
pub struct RepartitionNode {
#[prost(message, optional, boxed, tag = "1")]
pub input: ::core::option::Option<::prost::alloc::boxed::Box<LogicalPlanNode>>,
#[prost(oneof = "repartition_node::PartitionMethod", tags = "2, 3")]
#[prost(oneof = "repartition_node::PartitionMethod", tags = "2, 3, 4")]
pub partition_method: ::core::option::Option<repartition_node::PartitionMethod>,
}
/// Nested message and enum types in `RepartitionNode`.
Expand All @@ -221,9 +221,23 @@ pub mod repartition_node {
RoundRobin(u64),
#[prost(message, tag = "3")]
Hash(super::HashRepartition),
#[prost(message, tag = "4")]
Range(super::RangeRepartition),
}
}
/// Prost message struct for `RangeSplitPoint`: a repeated list of
/// `ScalarValue` messages.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RangeSplitPoint {
/// Scalar values of this split point (repeated message field, tag 1).
#[prost(message, repeated, tag = "1")]
pub value: ::prost::alloc::vec::Vec<super::datafusion_common::ScalarValue>,
}
/// Prost message struct for `RangeRepartition`: a list of sort expressions
/// together with a list of split points.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RangeRepartition {
/// Sort expressions (repeated message field, tag 1).
#[prost(message, repeated, tag = "1")]
pub expr: ::prost::alloc::vec::Vec<SortExprNode>,
/// Split points (repeated message field, tag 2).
#[prost(message, repeated, tag = "2")]
pub split_points: ::prost::alloc::vec::Vec<RangeSplitPoint>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct HashRepartition {
#[prost(message, repeated, tag = "1")]
pub hash_expr: ::prost::alloc::vec::Vec<LogicalExprNode>,
Expand Down
13 changes: 12 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, Field};
use datafusion_common::datatype::DataTypeExt;
use datafusion_common::{
NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference,
NullEquality, RecursionUnnestOption, Result, ScalarValue, SplitPoint, TableReference,
UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err,
};
use datafusion_execution::TaskContext;
Expand Down Expand Up @@ -899,3 +899,14 @@ fn parse_required_expr(
fn proto_error<S: Into<String>>(message: S) -> Error {
Error::General(message.into())
}

pub fn parse_protobuf_range_split_point(

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this function and `serialize_range_split_point` from the physical plan logic written by @gene-bordegaray! 😁

split_point: &protobuf::RangeSplitPoint,
) -> Result<SplitPoint, Error> {
let values = split_point
.value
.iter()
.map(ScalarValue::try_from)
.collect::<Result<_, Error>>()?;
Ok(SplitPoint::new(values))
}
31 changes: 25 additions & 6 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ use datafusion_datasource_json::file_format::{
use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{
AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType,
TableSource, Unnest, WriteOp,
AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RangePartitioning,
RecursiveQuery, SkipType, TableSource, Unnest, WriteOp,
};
use datafusion_expr::{
DistinctOn, DropView, Expr, JoinConstraint, LogicalPlan, LogicalPlanBuilder,
Expand All @@ -76,6 +76,7 @@ use datafusion_expr::{
use datafusion_proto_common::protobuf_common;

use self::to_proto::{serialize_expr, serialize_exprs};
use crate::logical_plan::to_proto::serialize_range_split_point;
use crate::logical_plan::to_proto::serialize_sorts;
use datafusion_catalog::TableProvider;
use datafusion_catalog::default_table_source::{provider_as_source, source_as_provider};
Expand Down Expand Up @@ -745,6 +746,16 @@ impl AsLogicalPlan for LogicalPlanNode {
PartitionMethod::RoundRobin(partition_count) => {
Partitioning::RoundRobinBatch(*partition_count as usize)
}
PartitionMethod::Range(protobuf::RangeRepartition {
expr: pb_sort_expr,
split_points,
}) => Partitioning::Range(RangePartitioning::try_new(
from_proto::parse_sorts(pb_sort_expr, ctx, extension_codec)?,
split_points
.iter()
.map(from_proto::parse_protobuf_range_split_point)
.collect::<Result<Vec<_>, _>>()?,
)?),
};

LogicalPlanBuilder::from(input)
Expand Down Expand Up @@ -1754,10 +1765,18 @@ impl AsLogicalPlan for LogicalPlanNode {
Partitioning::RoundRobinBatch(partition_count) => {
PartitionMethod::RoundRobin(*partition_count as u64)
}
Partitioning::Range(_) => {
// TODO: Support range repartition protobuf serialization.
// Tracked by https://github.com/apache/datafusion/issues/22787
return not_impl_err!("Range repartition");
Partitioning::Range(range_partitioning) => {
let ordering = range_partitioning.ordering();
let split_points = range_partitioning
.split_points()
.iter()
.map(serialize_range_split_point)
.collect::<Result<Vec<_>, _>>()?;

PartitionMethod::Range(protobuf::RangeRepartition {
expr: serialize_sorts(ordering, extension_codec)?,
split_points,
})
}
Partitioning::DistributeBy(_) => {
return not_impl_err!("DistributeBy");
Expand Down
Loading