diff --git a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/optimizers/pre_aggregation/pre_aggregations_compiler.rs b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/optimizers/pre_aggregation/pre_aggregations_compiler.rs index 809392076d4c5..c39cbe8948528 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/optimizers/pre_aggregation/pre_aggregations_compiler.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/optimizers/pre_aggregation/pre_aggregations_compiler.rs @@ -510,7 +510,7 @@ impl PreAggregationsCompiler { "Pre-aggregation time dimension must be a dimension" )) })?; - if dimension.dimension_type() != "time" { + if !dimension.is_time() { return Err(CubeError::user(format!( "Pre-aggregation time dimension must be a dimension" ))); diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/filter/base_filter.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/filter/base_filter.rs index 3c1d9a8ed28da..ae883a154223b 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/filter/base_filter.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/filter/base_filter.rs @@ -146,15 +146,11 @@ impl BaseFilter { let filters_context = context.filters_context(); let symbol = self.member_evaluator(); let member_type = match symbol.as_ref() { - MemberSymbol::Dimension(dimension_symbol) => Some( - dimension_symbol - .definition() - .static_data() - .dimension_type - .clone(), - ), + MemberSymbol::Dimension(dimension_symbol) => { + Some(dimension_symbol.dimension_type().to_string()) + } MemberSymbol::Measure(measure_symbol) => { - Some(measure_symbol.measure_type().clone()) + Some(measure_symbol.measure_type().to_string()) } _ => None, }; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs index 8542c3eeef982..ebf54eac8548a 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs @@ -181,7 +181,7 @@ impl MultiStageAppliedState { m.resolve_reference_chain() }; if let Ok(dim) = symbol.as_dimension() { - if dim.dimension_type() == "time" { + if dim.is_time() { Some(symbol) } else { None diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/multi_stage_query_planner.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/multi_stage_query_planner.rs index 4478a5c6a4c3d..dd85eb489d0a5 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/multi_stage_query_planner.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/multi_stage_query_planner.rs @@ -16,6 +16,7 @@ use crate::planner::sql_evaluator::collectors::member_childs; use crate::planner::sql_evaluator::Case; use crate::planner::sql_evaluator::CaseSwitchDefinition; use crate::planner::sql_evaluator::CaseSwitchItem; +use crate::planner::sql_evaluator::MeasureKind; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::GranularityHelper; use crate::planner::QueryProperties; @@ -114,12 +115,10 @@ impl MultiStageQueryPlanner { resolved_multi_stage_dimensions: &mut HashSet, ) -> Result<(MultiStageInodeMember, bool), CubeError> { let inode = if let Ok(measure) = base_member.as_measure() { - let member_type = if measure.measure_type() == "rank" { - MultiStageInodeMemberType::Rank - } else if !measure.is_calculated() { - MultiStageInodeMemberType::Aggregate - } else { - MultiStageInodeMemberType::Calculate + let member_type = match measure.kind() { + MeasureKind::Rank => MultiStageInodeMemberType::Rank, + MeasureKind::Calculated(_) => MultiStageInodeMemberType::Calculate, + _ => MultiStageInodeMemberType::Aggregate, }; let time_shift = measure.time_shift().clone(); diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/collectors/has_multi_stage_members.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/collectors/has_multi_stage_members.rs index a8678f5017cea..4e092c62162a3 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/collectors/has_multi_stage_members.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/collectors/has_multi_stage_members.rs @@ -32,9 +32,7 @@ impl TraversalVisitor for HasMultiStageMembersCollector { MemberSymbol::Measure(s) => { if s.is_multi_stage() { self.has_multi_stage = true; - } else if !self.ignore_cumulative - && (s.is_rolling_window() || s.measure_type() == "runningTotal") - { + } else if !self.ignore_cumulative && s.is_cumulative() { self.has_multi_stage = true; } } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call_builder.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call_builder.rs index ff1ae4a8dcb89..f28b19ea07431 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call_builder.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call_builder.rs @@ -132,7 +132,7 @@ impl<'a> SqlCallBuilder<'a> { ) -> Option { if let Ok(member_symbol) = self.build_evaluator(¤t_cube_name, &path_tail[0]) { if let Ok(dimension) = member_symbol.as_dimension() { - if dimension.dimension_type() == "time" && path_tail.len() == 2 { + if dimension.is_time() && path_tail.len() == 2 { let granularity = &path_tail[1]; if let Ok(Some(granularity_obj)) = GranularityHelper::make_granularity_obj( self.cube_evaluator.clone(), diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/evaluate_sql.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/evaluate_sql.rs index cb5eb1e6a84aa..e4cfd834fab8c 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/evaluate_sql.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/evaluate_sql.rs @@ -38,42 +38,12 @@ impl SqlNode for EvaluateSqlNode { let res = visitor.apply(&ev.base_symbol(), node_processor.clone(), templates)?; Ok(res) } - MemberSymbol::Measure(ev) => { - let res = if ev.has_sql() { - ev.evaluate_sql( - visitor, - node_processor.clone(), - query_tools.clone(), - templates, - )? - } else if ev.pk_sqls().len() > 1 { - let pk_strings = ev - .pk_sqls() - .iter() - .map(|pk| -> Result<_, CubeError> { - let res = pk.eval( - &visitor, - node_processor.clone(), - query_tools.clone(), - templates, - )?; - templates.cast_to_string(&res) - }) - .collect::, _>>()?; - templates.concat_strings(&pk_strings)? - } else if ev.pk_sqls().len() == 1 { - let pk_sql = ev.pk_sqls().first().unwrap(); - pk_sql.eval( - &visitor, - node_processor.clone(), - query_tools.clone(), - templates, - )? - } else { - format!("*") - }; - Ok(res) - } + MemberSymbol::Measure(ev) => ev.evaluate_sql( + visitor, + node_processor.clone(), + query_tools.clone(), + templates, + ), MemberSymbol::CubeTable(ev) => ev.evaluate_sql( visitor, node_processor.clone(), diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs index 3e5323ac968ae..7c06e438ea43b 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs @@ -1,6 +1,6 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; -use crate::planner::sql_evaluator::MeasureSymbol; +use crate::planner::sql_evaluator::symbols::{AggregateWrap, MeasureSymbol}; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -32,12 +32,27 @@ impl FinalMeasureSqlNode { &self.input } - fn is_count_distinct(&self, symbol: &MeasureSymbol) -> bool { - symbol.measure_type() == "countDistinct" - || (symbol.measure_type() == "count" - && self - .rendered_as_multiplied_measures - .contains(&symbol.full_name())) + fn wrap_aggregate( + &self, + ev: &MeasureSymbol, + input: String, + templates: &PlanSqlTemplates, + ) -> Result { + let is_multiplied = self + .rendered_as_multiplied_measures + .contains(&ev.full_name()); + match ev.kind().aggregate_wrap(is_multiplied) { + AggregateWrap::PassThrough => Ok(input), + AggregateWrap::Function(name) => Ok(format!("{}({})", name, input)), + AggregateWrap::CountDistinct => templates.count_distinct(&input), + AggregateWrap::CountDistinctApprox => { + if self.count_approx_as_state { + templates.hll_init(input) + } else { + templates.count_distinct_approx(input) + } + } + } } } @@ -59,26 +74,7 @@ impl SqlNode for FinalMeasureSqlNode { node_processor.clone(), templates, )?; - - if ev.is_calculated() || ev.measure_type() == "numberAgg" { - input - } else if ev.measure_type() == "countDistinctApprox" { - if self.count_approx_as_state { - templates.hll_init(input)? - } else { - templates.count_distinct_approx(input)? - } - } else if self.is_count_distinct(ev) { - templates.count_distinct(&input)? - } else { - let measure_type = if ev.measure_type() == "runningTotal" { - "sum" - } else { - &ev.measure_type() - }; - - format!("{}({})", measure_type, input) - } + self.wrap_aggregate(ev, input, templates)? } _ => { return Err(CubeError::internal(format!( diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_pre_aggregation_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_pre_aggregation_measure.rs index 1f9c75f728685..37a6dc000484f 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_pre_aggregation_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_pre_aggregation_measure.rs @@ -2,6 +2,7 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; use crate::planner::sql_evaluator::sql_nodes::RenderReferences; use crate::planner::sql_evaluator::sql_nodes::RenderReferencesType; +use crate::planner::sql_evaluator::symbols::AggregateWrap; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -48,14 +49,14 @@ impl SqlNode for FinalPreAggregationMeasureSqlNode { table_ref, templates.quote_identifier(&column_name.name())? ); - if ev.measure_type() == "count" || ev.measure_type() == "sum" { - format!("sum({})", pre_aggregation_measure) - } else if ev.measure_type() == "countDistinctApprox" { - templates.count_distinct_approx(pre_aggregation_measure)? - } else if ev.measure_type() == "min" || ev.measure_type() == "max" { - format!("{}({})", ev.measure_type(), pre_aggregation_measure) - } else { - format!("sum({})", pre_aggregation_measure) + match ev.kind().pre_aggregate_wrap() { + AggregateWrap::CountDistinctApprox => { + templates.count_distinct_approx(pre_aggregation_measure)? + } + AggregateWrap::Function(name) => { + format!("{}({})", name, pre_aggregation_measure) + } + _ => format!("sum({})", pre_aggregation_measure), } } RenderReferencesType::LiteralValue(value) => { diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs index d6691f59d3942..c74ca07a708b0 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs @@ -1,5 +1,6 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::symbols::DimensionKind; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -32,31 +33,20 @@ impl SqlNode for GeoDimensionSqlNode { ) -> Result { let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { - if ev.dimension_type() == "geo" { - if let (Some(latitude), Some(longitude)) = (ev.latitude(), ev.longitude()) { - let latitude_str = latitude.eval( - visitor, - node_processor.clone(), - query_tools.clone(), - templates, - )?; - let longitude_str = longitude.eval( - visitor, - node_processor.clone(), - query_tools.clone(), - templates, - )?; - templates.concat_strings(&vec![ - latitude_str, - format!("','"), - longitude_str, - ])? - } else { - return Err(CubeError::user(format!( - "Geo dimension '{}' must have latitude and longitude", - ev.full_name() - ))); - } + if let DimensionKind::Geo(geo) = ev.kind() { + let latitude_str = geo.latitude().eval( + visitor, + node_processor.clone(), + query_tools.clone(), + templates, + )?; + let longitude_str = geo.longitude().eval( + visitor, + node_processor.clone(), + query_tools.clone(), + templates, + )?; + templates.concat_strings(&vec![latitude_str, format!("','"), longitude_str])? } else { self.input.to_sql( visitor, diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs index d1200b8fae00d..c2b9f52e30671 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/multi_stage_rank.rs @@ -1,5 +1,6 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::symbols::MeasureKind; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -40,7 +41,7 @@ impl SqlNode for MultiStageRankNode { ) -> Result { let res = match node.as_ref() { MemberSymbol::Measure(m) => { - if m.is_multi_stage() && m.measure_type() == "rank" { + if m.is_multi_stage() && matches!(m.kind(), MeasureKind::Rank) { let order_by = if !m.measure_order_by().is_empty() { let sql = m .measure_order_by() diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs index c8546ecd85d9f..133bfde729376 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/rolling_window.rs @@ -1,5 +1,6 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::symbols::{AggregationType, MeasureKind}; use crate::planner::sql_evaluator::{MemberSymbol, SqlEvaluatorVisitor}; use crate::planner::sql_templates::PlanSqlTemplates; use cubenativeutils::CubeError; @@ -43,25 +44,35 @@ impl SqlNode for RollingWindowNode { node_processor.clone(), templates, )?; - if m.measure_type() == "countDistinctApprox" { - templates.hll_cardinality_merge(input)? - } else { - if m.measure_type() == "sum" - || m.measure_type() == "count" - || m.measure_type() == "runningTotal" + match m.kind() { + MeasureKind::Aggregated(a) + if a.agg_type() == AggregationType::CountDistinctApprox => { - format!("sum({})", input) - } else if m.measure_type() == "min" || m.measure_type() == "max" { - format!("{}({})", m.measure_type(), input) - } else { - self.default_processor.to_sql( + templates.hll_cardinality_merge(input)? + } + MeasureKind::Count(_) => format!("sum({})", input), + MeasureKind::Aggregated(a) => match a.agg_type() { + AggregationType::Sum | AggregationType::RunningTotal => { + format!("sum({})", input) + } + AggregationType::Min | AggregationType::Max => { + format!("{}({})", a.agg_type().as_str(), input) + } + _ => self.default_processor.to_sql( visitor, node, query_tools.clone(), node_processor, templates, - )? - } + )?, + }, + _ => self.default_processor.to_sql( + visitor, + node, + query_tools.clone(), + node_processor, + templates, + )?, } } else { self.default_processor.to_sql( diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs index c3b806276199e..1c0a37f9aa8da 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs @@ -41,7 +41,7 @@ impl SqlNode for TimeShiftSqlNode { )?; let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { - if !ev.is_reference() && ev.dimension_type() == "time" { + if !ev.is_reference() && ev.is_time() { if let Some(shift) = self.shifts.dimensions_shifts.get(&ev.full_name()) { let shift = shift.interval.clone().unwrap().to_sql(); // Common time shifts should always have an interval let res = templates.add_timestamp_interval(input, shift)?; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs index 305e58a881c93..0dbbd94539de4 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/ungroupped_query_final_measure.rs @@ -1,5 +1,6 @@ use super::SqlNode; use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::symbols::{AggregationType, MeasureKind}; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; @@ -43,10 +44,15 @@ impl SqlNode for UngroupedQueryFinalMeasureSqlNode { if input == "*" { "1".to_string() } else { - if ev.measure_type() == "count" - || ev.measure_type() == "countDistinct" - || ev.measure_type() == "countDistinctApprox" - { + let is_count_like = match ev.kind() { + MeasureKind::Count(_) => true, + MeasureKind::Aggregated(a) => matches!( + a.agg_type(), + AggregationType::CountDistinct | AggregationType::CountDistinctApprox + ), + _ => false, + }; + if is_count_like { format!("CASE WHEN ({}) IS NOT NULL THEN 1 END", input) //TODO templates!! } else { input diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/aggregation_type.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/aggregation_type.rs new file mode 100644 index 0000000000000..346c3a1fdec94 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/aggregation_type.rs @@ -0,0 +1,177 @@ +use cubenativeutils::CubeError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AggregationType { + Sum, + Avg, + Min, + Max, + CountDistinct, + CountDistinctApprox, + NumberAgg, + RunningTotal, +} + +impl AggregationType { + pub fn from_str(s: &str) -> Result { + match s { + "sum" => Ok(Self::Sum), + "avg" => Ok(Self::Avg), + "min" => Ok(Self::Min), + "max" => Ok(Self::Max), + "countDistinct" | "count_distinct" => Ok(Self::CountDistinct), + "countDistinctApprox" | "count_distinct_approx" => Ok(Self::CountDistinctApprox), + "numberAgg" | "number_agg" => Ok(Self::NumberAgg), + "runningTotal" | "running_total" => Ok(Self::RunningTotal), + other => Err(CubeError::user(format!( + "Unknown aggregation type: '{}'", + other + ))), + } + } + + pub fn is_additive(&self) -> bool { + matches!( + self, + Self::Sum | Self::Min | Self::Max | Self::CountDistinctApprox | Self::RunningTotal + ) + } + + pub fn is_distinct(&self) -> bool { + matches!(self, Self::CountDistinct | Self::CountDistinctApprox) + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::Sum => "sum", + Self::Avg => "avg", + Self::Min => "min", + Self::Max => "max", + Self::CountDistinct => "countDistinct", + Self::CountDistinctApprox => "countDistinctApprox", + Self::NumberAgg => "numberAgg", + Self::RunningTotal => "runningTotal", + } + } +} + +impl TryFrom<&str> for AggregationType { + type Error = CubeError; + + fn try_from(s: &str) -> Result { + Self::from_str(s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_str_camel_case() { + assert_eq!( + AggregationType::from_str("sum").unwrap(), + AggregationType::Sum + ); + assert_eq!( + AggregationType::from_str("avg").unwrap(), + AggregationType::Avg + ); + assert_eq!( + AggregationType::from_str("min").unwrap(), + AggregationType::Min + ); + assert_eq!( + AggregationType::from_str("max").unwrap(), + AggregationType::Max + ); + assert_eq!( + AggregationType::from_str("countDistinct").unwrap(), + AggregationType::CountDistinct + ); + assert_eq!( + AggregationType::from_str("countDistinctApprox").unwrap(), + AggregationType::CountDistinctApprox + ); + assert_eq!( + AggregationType::from_str("numberAgg").unwrap(), + AggregationType::NumberAgg + ); + assert_eq!( + AggregationType::from_str("runningTotal").unwrap(), + AggregationType::RunningTotal + ); + } + + #[test] + fn test_from_str_snake_case() { + assert_eq!( + AggregationType::from_str("count_distinct").unwrap(), + AggregationType::CountDistinct + ); + assert_eq!( + AggregationType::from_str("count_distinct_approx").unwrap(), + AggregationType::CountDistinctApprox + ); + assert_eq!( + AggregationType::from_str("number_agg").unwrap(), + AggregationType::NumberAgg + ); + assert_eq!( + AggregationType::from_str("running_total").unwrap(), + AggregationType::RunningTotal + ); + } + + #[test] + fn test_is_additive() { + assert!(AggregationType::Sum.is_additive()); + assert!(AggregationType::Min.is_additive()); + assert!(AggregationType::Max.is_additive()); + assert!(!AggregationType::Avg.is_additive()); + assert!(!AggregationType::CountDistinct.is_additive()); + assert!(AggregationType::CountDistinctApprox.is_additive()); + assert!(!AggregationType::NumberAgg.is_additive()); + assert!(AggregationType::RunningTotal.is_additive()); + } + + #[test] + fn test_is_distinct() { + assert!(AggregationType::CountDistinct.is_distinct()); + assert!(AggregationType::CountDistinctApprox.is_distinct()); + assert!(!AggregationType::Sum.is_distinct()); + assert!(!AggregationType::Avg.is_distinct()); + assert!(!AggregationType::Min.is_distinct()); + assert!(!AggregationType::Max.is_distinct()); + assert!(!AggregationType::NumberAgg.is_distinct()); + assert!(!AggregationType::RunningTotal.is_distinct()); + } + + #[test] + fn test_as_str_round_trip() { + let variants = [ + AggregationType::Sum, + AggregationType::Avg, + AggregationType::Min, + AggregationType::Max, + AggregationType::CountDistinct, + AggregationType::CountDistinctApprox, + AggregationType::NumberAgg, + AggregationType::RunningTotal, + ]; + for v in &variants { + let s = v.as_str(); + let parsed = AggregationType::from_str(s).unwrap(); + assert_eq!(*v, parsed); + } + } + + #[test] + fn test_try_from() { + let result: Result = "sum".try_into(); + assert_eq!(result.unwrap(), AggregationType::Sum); + + let result: Result = "unknown".try_into(); + assert!(result.is_err()); + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/case.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/case.rs index 9123e7612e0fd..ae36170b4c29f 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/case.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/case.rs @@ -231,8 +231,8 @@ impl CaseSwitchDefinition { fn get_switch_values(&self) -> Option> { if let CaseSwitchItem::Member(member) = &self.switch { if let Ok(switch_dim) = member.as_dimension() { - if switch_dim.dimension_type() == "switch" { - return Some(switch_dim.values().clone()); + if switch_dim.is_switch() { + return Some(switch_dim.values().to_vec()); } } } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/dimension_type.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/dimension_type.rs new file mode 100644 index 0000000000000..a0a4c435e16e3 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/dimension_type.rs @@ -0,0 +1,52 @@ +use cubenativeutils::CubeError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DimensionType { + String, + Number, + Boolean, + Time, +} + +impl DimensionType { + pub fn from_str(s: &str) -> Result { + match s { + "string" => Ok(Self::String), + "number" => Ok(Self::Number), + "boolean" => Ok(Self::Boolean), + "time" => Ok(Self::Time), + other => Err(CubeError::user(format!( + "Unknown dimension type: '{}'", + other + ))), + } + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::String => "string", + Self::Number => "number", + Self::Boolean => "boolean", + Self::Time => "time", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_str_round_trip() { + for s in &["string", "number", "boolean", "time"] { + let dt = DimensionType::from_str(s).unwrap(); + assert_eq!(dt.as_str(), *s); + } + } + + #[test] + fn test_unknown_type_error() { + let result = DimensionType::from_str("unknown"); + assert!(result.is_err()); + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/mod.rs index 55173beb6c8cc..ed8bf27a461d2 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/mod.rs @@ -1,7 +1,11 @@ +mod aggregation_type; mod case; +mod dimension_type; mod static_filter; mod symbol_path; +pub use aggregation_type::*; pub use case::*; +pub use dimension_type::*; pub use static_filter::*; pub use symbol_path::*; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/static_filter.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/static_filter.rs index 5c8ddac81c12e..58291ef702da0 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/static_filter.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/common/static_filter.rs @@ -19,7 +19,7 @@ pub fn find_value_restriction( pub fn get_filtered_values(symbol: &Rc, filter: &Option) -> Vec { if let Ok(dim) = symbol.as_dimension() { - if dim.dimension_type() == "switch" { + if dim.is_switch() { if let Some(filter) = filter { if let Some(values) = find_value_restriction(&filter.items, symbol) { let res = dim @@ -32,7 +32,7 @@ pub fn get_filtered_values(symbol: &Rc, filter: &Option) - } } } - return dim.values().clone(); + return dim.values().to_vec(); } vec![] diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/case_dimension.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/case_dimension.rs new file mode 100644 index 0000000000000..cbbbab3df9f6f --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/case_dimension.rs @@ -0,0 +1,109 @@ +use super::super::common::{Case, DimensionType}; +use super::super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub struct CaseDimension { + dimension_type: DimensionType, + case: Case, + member_sql: Option>, +} + +impl CaseDimension { + pub fn new(dimension_type: DimensionType, case: Case, member_sql: Option>) -> Self { + Self { + dimension_type, + case, + member_sql, + } + } + + pub fn dimension_type(&self) -> &DimensionType { + &self.dimension_type + } + + pub fn case(&self) -> &Case { + &self.case + } + + pub fn member_sql(&self) -> Option<&Rc> { + self.member_sql.as_ref() + } + + pub fn replace_case(&self, new_case: Case) -> Self { + Self { + dimension_type: self.dimension_type, + case: new_case, + member_sql: self.member_sql.clone(), + } + } + + pub fn evaluate_sql( + &self, + full_name: &str, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + if let Some(member_sql) = &self.member_sql { + member_sql.eval(visitor, node_processor, query_tools, templates) + } else { + Err(CubeError::internal(format!( + "Dimension {} hasn't sql evaluator", + full_name + ))) + } + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + if let Some(member_sql) = &self.member_sql { + member_sql.extract_symbol_deps(&mut deps); + } + self.case.extract_symbol_deps(&mut deps); + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + if let Some(member_sql) = &self.member_sql { + member_sql.extract_symbol_deps_with_path(&mut deps); + } + self.case.extract_symbol_deps_with_path(&mut deps); + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + let member_sql = if let Some(sql) = &self.member_sql { + Some(sql.apply_recursive(f)?) + } else { + None + }; + Ok(Self { + dimension_type: self.dimension_type, + case: self.case.apply_to_deps(f)?, + member_sql, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(self.member_sql.iter().chain(self.case.iter_sql_calls())) + } + + pub fn is_owned_by_cube(&self) -> bool { + let mut owned = false; + if let Some(sql) = &self.member_sql { + owned |= sql.is_owned_by_cube(); + } + owned |= self.case.is_owned_by_cube(); + owned + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/geo.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/geo.rs new file mode 100644 index 0000000000000..4122699e6ab2b --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/geo.rs @@ -0,0 +1,59 @@ +use super::super::MemberSymbol; +use crate::planner::sql_evaluator::SqlCall; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub struct GeoDimension { + latitude: Rc, + longitude: Rc, +} + +impl GeoDimension { + pub fn new(latitude: Rc, longitude: Rc) -> Self { + Self { + latitude, + longitude, + } + } + + pub fn latitude(&self) -> &Rc { + &self.latitude + } + + pub fn longitude(&self) -> &Rc { + &self.longitude + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + self.latitude.extract_symbol_deps(&mut deps); + self.longitude.extract_symbol_deps(&mut deps); + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + self.latitude.extract_symbol_deps_with_path(&mut deps); + self.longitude.extract_symbol_deps_with_path(&mut deps); + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(Self { + latitude: self.latitude.apply_recursive(f)?, + longitude: self.longitude.apply_recursive(f)?, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(std::iter::once(&self.latitude).chain(std::iter::once(&self.longitude))) + } + + pub fn is_owned_by_cube(&self) -> bool { + self.latitude.is_owned_by_cube() || self.longitude.is_owned_by_cube() + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/mod.rs new file mode 100644 index 0000000000000..9d7c7c4a47063 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/mod.rs @@ -0,0 +1,135 @@ +mod case_dimension; +mod geo; +mod regular; +mod switch; + +pub use case_dimension::*; +pub use geo::*; +pub use regular::*; +pub use switch::*; + +use super::common::DimensionType; +use super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub enum DimensionKind { + Regular(RegularDimension), + Geo(GeoDimension), + Switch(SwitchDimension), + Case(CaseDimension), +} + +impl DimensionKind { + pub fn evaluate_sql( + &self, + name: &str, + full_name: &str, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + match self { + Self::Regular(r) => r.evaluate_sql(visitor, node_processor, query_tools, templates), + Self::Geo(_) => Err(CubeError::internal(format!( + "Geo dimension {} doesn't support evaluate_sql directly", + full_name + ))), + Self::Switch(s) => { + s.evaluate_sql(name, visitor, node_processor, query_tools, templates) + } + Self::Case(c) => { + c.evaluate_sql(full_name, visitor, node_processor, query_tools, templates) + } + } + } + + pub fn get_dependencies(&self) -> Vec> { + match self { + Self::Regular(r) => r.get_dependencies(), + Self::Geo(g) => g.get_dependencies(), + Self::Switch(s) => s.get_dependencies(), + Self::Case(c) => c.get_dependencies(), + } + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + match self { + Self::Regular(r) => r.get_dependencies_with_path(), + Self::Geo(g) => g.get_dependencies_with_path(), + Self::Switch(s) => s.get_dependencies_with_path(), + Self::Case(c) => c.get_dependencies_with_path(), + } + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(match self { + Self::Regular(r) => Self::Regular(r.apply_to_deps(f)?), + Self::Geo(g) => Self::Geo(g.apply_to_deps(f)?), + Self::Switch(s) => Self::Switch(s.apply_to_deps(f)?), + Self::Case(c) => Self::Case(c.apply_to_deps(f)?), + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + match self { + Self::Regular(r) => r.iter_sql_calls(), + Self::Geo(g) => g.iter_sql_calls(), + Self::Switch(s) => s.iter_sql_calls(), + Self::Case(c) => c.iter_sql_calls(), + } + } + + pub fn is_owned_by_cube(&self) -> bool { + match self { + Self::Regular(r) => r.is_owned_by_cube(), + Self::Geo(g) => g.is_owned_by_cube(), + Self::Switch(s) => s.is_owned_by_cube(), + Self::Case(c) => c.is_owned_by_cube(), + } + } + + pub fn is_time(&self) -> bool { + match self { + Self::Regular(r) => *r.dimension_type() == DimensionType::Time, + Self::Case(c) => *c.dimension_type() == DimensionType::Time, + _ => false, + } + } + + pub fn is_geo(&self) -> bool { + matches!(self, Self::Geo(_)) + } + + pub fn is_switch(&self) -> bool { + matches!(self, Self::Switch(_)) + } + + pub fn is_case(&self) -> bool { + matches!(self, Self::Case(_)) + } + + pub fn is_calc_group(&self) -> bool { + match self { + Self::Switch(s) => s.is_calc_group(), + _ => false, + } + } + + pub fn dimension_type_str(&self) -> &str { + match self { + Self::Regular(r) => r.dimension_type().as_str(), + Self::Geo(_) => "geo", + Self::Switch(_) => "switch", + Self::Case(c) => c.dimension_type().as_str(), + } + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/regular.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/regular.rs new file mode 100644 index 0000000000000..7eb393b13e692 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/regular.rs @@ -0,0 +1,71 @@ +use super::super::common::DimensionType; +use super::super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub struct RegularDimension { + dimension_type: DimensionType, + member_sql: Rc, +} + +impl RegularDimension { + pub fn new(dimension_type: DimensionType, member_sql: Rc) -> Self { + Self { + dimension_type, + member_sql, + } + } + + pub fn dimension_type(&self) -> &DimensionType { + &self.dimension_type + } + + pub fn member_sql(&self) -> &Rc { + &self.member_sql + } + + pub fn evaluate_sql( + &self, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + self.member_sql + .eval(visitor, node_processor, query_tools, templates) + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + self.member_sql.extract_symbol_deps(&mut deps); + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + self.member_sql.extract_symbol_deps_with_path(&mut deps); + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(Self { + dimension_type: self.dimension_type, + member_sql: self.member_sql.apply_recursive(f)?, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(std::iter::once(&self.member_sql)) + } + + pub fn is_owned_by_cube(&self) -> bool { + self.member_sql.is_owned_by_cube() + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/switch.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/switch.rs new file mode 100644 index 0000000000000..a4a101fc12b19 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_kinds/switch.rs @@ -0,0 +1,84 @@ +use super::super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub struct SwitchDimension { + values: Vec, + member_sql: Option>, +} + +impl SwitchDimension { + pub fn new(values: Vec, member_sql: Option>) -> Self { + Self { values, member_sql } + } + + pub fn values(&self) -> &[String] { + &self.values + } + + pub fn member_sql(&self) -> Option<&Rc> { + self.member_sql.as_ref() + } + + pub fn is_calc_group(&self) -> bool { + self.member_sql.is_none() + } + + pub fn evaluate_sql( + &self, + name: &str, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + if let Some(member_sql) = &self.member_sql { + member_sql.eval(visitor, node_processor, query_tools, templates) + } else { + Ok(templates.quote_identifier(name)?) + } + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + if let Some(member_sql) = &self.member_sql { + member_sql.extract_symbol_deps(&mut deps); + } + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + if let Some(member_sql) = &self.member_sql { + member_sql.extract_symbol_deps_with_path(&mut deps); + } + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + let member_sql = if let Some(sql) = &self.member_sql { + Some(sql.apply_recursive(f)?) + } else { + None + }; + Ok(Self { + values: self.values.clone(), + member_sql, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(self.member_sql.iter()) + } + + pub fn is_owned_by_cube(&self) -> bool { + false + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_symbol.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_symbol.rs index 19b4240424b30..bce97f970cbf3 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_symbol.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/dimension_symbol.rs @@ -1,6 +1,9 @@ use super::common::Case; +use super::dimension_kinds::{ + CaseDimension, DimensionKind, GeoDimension, RegularDimension, SwitchDimension, +}; use super::SymbolPath; -use super::{MemberSymbol, SymbolFactory}; +use super::{DimensionType, MemberSymbol, SymbolFactory}; use crate::cube_bridge::dimension_definition::DimensionDefinition; use crate::cube_bridge::evaluator::CubeEvaluator; use crate::cube_bridge::member_sql::MemberSql; @@ -24,21 +27,14 @@ pub struct CalendarDimensionTimeShift { pub struct DimensionSymbol { cube: Rc, name: String, - dimension_type: String, + kind: DimensionKind, alias: String, - member_sql: Option>, - latitude: Option>, - longitude: Option>, - values: Vec, - case: Option, - definition: Rc, is_reference: bool, // Symbol is a direct reference to another symbol without any calculations is_view: bool, add_group_by: Option>>, time_shift: Vec, time_shift_pk_full_name: Option, is_self_time_shift_pk: bool, // If the dimension itself is a primary key and has time shifts, we can not reevaluate itself again while processing time shifts to avoid infinite recursion. So we raise this flag instead. - owned_by_cube: bool, is_multi_stage: bool, is_sub_query: bool, propagate_filters_to_sub_query: bool, @@ -48,21 +44,14 @@ impl DimensionSymbol { pub fn new( cube: Rc, name: String, - dimension_type: String, + kind: DimensionKind, alias: String, - member_sql: Option>, is_reference: bool, is_view: bool, - latitude: Option>, - longitude: Option>, - values: Vec, - case: Option, - definition: Rc, add_group_by: Option>>, time_shift: Vec, time_shift_pk_full_name: Option, is_self_time_shift_pk: bool, - owned_by_cube: bool, is_multi_stage: bool, is_sub_query: bool, propagate_filters_to_sub_query: bool, @@ -70,21 +59,14 @@ impl DimensionSymbol { Rc::new(Self { cube, name, - dimension_type, + kind, alias, - member_sql, is_reference, - latitude, - longitude, - values, - definition, - add_group_by, - case, is_view, + add_group_by, time_shift, time_shift_pk_full_name, is_self_time_shift_pk, - owned_by_cube, is_multi_stage, is_sub_query, propagate_filters_to_sub_query, @@ -98,53 +80,53 @@ impl DimensionSymbol { query_tools: Rc, templates: &PlanSqlTemplates, ) -> Result { - if self.member_sql.is_none() && self.dimension_type == "switch" { - Ok(templates.quote_identifier(&self.name)?) //We don't return cube_name - - //it should be added in - //autoprefix processing - } else if let Some(member_sql) = &self.member_sql { - let sql = member_sql.eval(visitor, node_processor, query_tools, templates)?; - Ok(sql) - } else { - Err(CubeError::internal(format!( - "Dimension {} hasn't sql evaluator", - self.full_name() - ))) - } + self.kind.evaluate_sql( + &self.name, + &self.full_name(), + visitor, + node_processor, + query_tools, + templates, + ) } pub fn is_calc_group(&self) -> bool { - self.member_sql.is_none() && self.dimension_type == "switch" + self.kind.is_calc_group() } - pub fn values(&self) -> &Vec { - &self.values + pub fn values(&self) -> &[String] { + match &self.kind { + DimensionKind::Switch(s) => s.values(), + _ => &[], + } } pub(super) fn replace_case(&self, new_case: Case) -> Rc { let mut new = self.clone(); if new_case.is_single_value() { - //FIXME - Hack: we don’t treat a single-element case as a multi-stage dimension + //FIXME - Hack: we don't treat a single-element case as a multi-stage dimension new.is_multi_stage = false; } - new.case = Some(new_case); + if let DimensionKind::Case(ref c) = new.kind { + new.kind = DimensionKind::Case(c.replace_case(new_case)); + } Rc::new(new) } - pub fn latitude(&self) -> Option> { - self.latitude.clone() - } - - pub fn longitude(&self) -> Option> { - self.longitude.clone() - } - pub fn case(&self) -> Option<&Case> { - self.case.as_ref() + match &self.kind { + DimensionKind::Case(c) => Some(c.case()), + _ => None, + } } - pub fn member_sql(&self) -> &Option> { - &self.member_sql + pub fn member_sql(&self) -> Option<&Rc> { + match &self.kind { + DimensionKind::Regular(r) => Some(r.member_sql()), + DimensionKind::Switch(s) => s.member_sql(), + DimensionKind::Case(c) => c.member_sql(), + DimensionKind::Geo(_) => None, + } } pub fn time_shift(&self) -> &Vec { @@ -164,7 +146,7 @@ impl DimensionSymbol { } pub fn owned_by_cube(&self) -> bool { - self.owned_by_cube + !self.is_multi_stage && !self.kind.is_switch() && self.kind.is_owned_by_cube() } pub fn is_multi_stage(&self) -> bool { @@ -179,8 +161,28 @@ impl DimensionSymbol { &self.add_group_by } - pub fn dimension_type(&self) -> &String { - &self.dimension_type + pub fn dimension_type(&self) -> &str { + self.kind.dimension_type_str() + } + + pub fn kind(&self) -> &DimensionKind { + &self.kind + } + + pub fn is_time(&self) -> bool { + self.kind.is_time() + } + + pub fn is_geo(&self) -> bool { + self.kind.is_geo() + } + + pub fn is_switch(&self) -> bool { + self.kind.is_switch() + } + + pub fn is_case(&self) -> bool { + self.kind.is_case() } pub fn propagate_filters_to_sub_query(&self) -> bool { @@ -211,73 +213,24 @@ impl DimensionSymbol { f: &F, ) -> Result, CubeError> { let mut result = self.clone(); - if let Some(member_sql) = &self.member_sql { - result.member_sql = Some(member_sql.apply_recursive(f)?); - } - if let Some(latitude) = &self.latitude { - result.latitude = Some(latitude.apply_recursive(f)?); - } - if let Some(longitude) = &self.longitude { - result.longitude = Some(longitude.apply_recursive(f)?); - } - - if let Some(case) = &self.case { - result.case = Some(case.apply_to_deps(f)?) - } - + result.kind = self.kind.apply_to_deps(f)?; Ok(MemberSymbol::new_dimension(Rc::new(result))) } pub fn iter_sql_calls(&self) -> Box> + '_> { - let result = self - .member_sql - .iter() - .chain(self.latitude.iter()) - .chain(self.longitude.iter()) - .chain(self.case.iter().flat_map(|case| case.iter_sql_calls())); - Box::new(result) + self.kind.iter_sql_calls() } pub fn get_dependencies(&self) -> Vec> { - let mut deps = vec![]; - if let Some(member_sql) = &self.member_sql { - member_sql.extract_symbol_deps(&mut deps); - } - if let Some(member_sql) = &self.latitude { - member_sql.extract_symbol_deps(&mut deps); - } - if let Some(member_sql) = &self.longitude { - member_sql.extract_symbol_deps(&mut deps); - } - if let Some(case) = &self.case { - case.extract_symbol_deps(&mut deps); - } - deps + self.kind.get_dependencies() } pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { - let mut deps = vec![]; - if let Some(member_sql) = &self.member_sql { - member_sql.extract_symbol_deps_with_path(&mut deps); - } - if let Some(member_sql) = &self.latitude { - member_sql.extract_symbol_deps_with_path(&mut deps); - } - if let Some(member_sql) = &self.longitude { - member_sql.extract_symbol_deps_with_path(&mut deps); - } - if let Some(case) = &self.case { - case.extract_symbol_deps_with_path(&mut deps); - } - deps + self.kind.get_dependencies_with_path() } pub fn cube_name(&self) -> &String { - &self.cube.cube_name() - } - - pub fn definition(&self) -> &Rc { - &self.definition + self.cube.cube_name() } pub fn join_map(&self) -> &Option>> { @@ -367,29 +320,7 @@ impl SymbolFactory for DimensionSymbolFactory { None }; - let is_sql_direct_ref = if let Some(sql) = &sql { - sql.is_direct_reference() - } else { - false - }; - - let (latitude, longitude) = if dimension_type == "geo" { - if let (Some(latitude_item), Some(longitude_item)) = - (definition.latitude()?, definition.longitude()?) - { - let latitude = compiler.compile_sql_call(path.cube_name(), latitude_item.sql()?)?; - let longitude = - compiler.compile_sql_call(path.cube_name(), longitude_item.sql()?)?; - (Some(latitude), Some(longitude)) - } else { - return Err(CubeError::user(format!( - "Geo dimension '{}'must have latitude and longitude", - path.full_name() - ))); - } - } else { - (None, None) - }; + let is_sql_direct_ref = sql.as_ref().is_some_and(|s| s.is_direct_reference()); let case = if let Some(native_case) = definition.case()? { Some(Case::try_new(path.cube_name(), native_case, compiler)?) @@ -467,12 +398,6 @@ impl SymbolFactory for DimensionSymbolFactory { None }; - let values = if definition.static_data().dimension_type == "switch" { - definition.static_data().values.clone().unwrap_or_default() - } else { - vec![] - }; - let add_group_by = if let Some(add_group_by) = &definition.static_data().add_group_by_references { let symbols = add_group_by @@ -487,31 +412,49 @@ impl SymbolFactory for DimensionSymbolFactory { let is_sub_query = definition.static_data().sub_query.unwrap_or(false); let is_multi_stage = definition.static_data().multi_stage.unwrap_or(false); - let owned_by_cube = if is_multi_stage || dimension_type == "switch" { - false - } else { - let mut owned = false; - if let Some(sql) = &sql { - owned |= sql.is_owned_by_cube(); - } - if let Some(sql) = &latitude { - owned |= sql.is_owned_by_cube(); - } - if let Some(sql) = &longitude { - owned |= sql.is_owned_by_cube(); + let kind = if let Some(case_val) = case { + let dim_type = DimensionType::from_str(&dimension_type)?; + DimensionKind::Case(CaseDimension::new(dim_type, case_val, sql)) + } else if dimension_type == "geo" { + if let (Some(lat_item), Some(lon_item)) = + (definition.latitude()?, definition.longitude()?) + { + let latitude = compiler.compile_sql_call(path.cube_name(), lat_item.sql()?)?; + let longitude = compiler.compile_sql_call(path.cube_name(), lon_item.sql()?)?; + DimensionKind::Geo(GeoDimension::new(latitude, longitude)) + } else { + return Err(CubeError::user(format!( + "Geo dimension '{}' must have latitude and longitude", + path.full_name() + ))); } - if let Some(case) = &case { - owned |= case.is_owned_by_cube(); + } else if dimension_type == "switch" { + let values = definition.static_data().values.clone().unwrap_or_default(); + DimensionKind::Switch(SwitchDimension::new(values, sql)) + } else { + let dim_type = DimensionType::from_str(&dimension_type)?; + match sql { + Some(sql) => DimensionKind::Regular(RegularDimension::new(dim_type, sql)), + None => { + return Err(CubeError::internal(format!( + "Dimension '{}' must have sql", + path.full_name() + ))); + } } - owned + }; + + let owned_by_cube = if is_multi_stage || kind.is_switch() { + false + } else { + kind.is_owned_by_cube() }; let is_reference = (is_view && is_sql_direct_ref) || (!owned_by_cube && !is_sub_query && is_sql_direct_ref - && case.is_none() - && latitude.is_none() - && longitude.is_none() + && !kind.is_case() + && !kind.is_geo() && !is_multi_stage); let propagate_filters_to_sub_query = definition @@ -526,21 +469,14 @@ impl SymbolFactory for DimensionSymbolFactory { let symbol = MemberSymbol::new_dimension(DimensionSymbol::new( cube_symbol, path.symbol_name().clone(), - dimension_type, + kind, alias, - sql, is_reference, is_view, - latitude, - longitude, - values, - case, - definition, add_group_by, time_shift, time_shift_pk, is_self_time_shift_pk, - owned_by_cube, is_multi_stage, is_sub_query, propagate_filters_to_sub_query, @@ -576,33 +512,23 @@ impl SymbolFactory for DimensionSymbolFactory { impl crate::utils::debug::DebugSql for DimensionSymbol { fn debug_sql(&self, expand_deps: bool) -> String { - if let Some(case) = &self.case { - return case.debug_sql(expand_deps); - } - - if self.dimension_type == "geo" { - let lat = self - .latitude - .as_ref() - .map(|sql| sql.debug_sql(expand_deps)) - .unwrap_or_else(|| "{missing_latitude}".to_string()); - let lon = self - .longitude - .as_ref() - .map(|sql| sql.debug_sql(expand_deps)) - .unwrap_or_else(|| "{missing_longitude}".to_string()); - return format!("GEO({}, {})", lat, lon); - } - - if self.dimension_type == "switch" && self.member_sql.is_none() { - return format!("SWITCH({})", self.full_name()); + match &self.kind { + DimensionKind::Case(c) => c.case().debug_sql(expand_deps), + DimensionKind::Geo(g) => { + let lat = g.latitude().debug_sql(expand_deps); + let lon = g.longitude().debug_sql(expand_deps); + format!("GEO({}, {})", lat, lon) + } + DimensionKind::Switch(s) if s.is_calc_group() => { + format!("SWITCH({})", self.full_name()) + } + _ => { + if let Some(sql) = self.member_sql() { + sql.debug_sql(expand_deps) + } else { + "".to_string() + } + } } - - let res = if let Some(sql) = &self.member_sql { - sql.debug_sql(expand_deps) - } else { - "".to_string() - }; - res } } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/aggregated.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/aggregated.rs new file mode 100644 index 0000000000000..fc4b9dd914242 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/aggregated.rs @@ -0,0 +1,92 @@ +use super::super::super::MemberSymbol; +use super::super::common::AggregationType; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub struct AggregatedMeasure { + agg_type: AggregationType, + member_sql: Option>, +} + +impl AggregatedMeasure { + pub fn new(agg_type: AggregationType, member_sql: Rc) -> Self { + Self { + agg_type, + member_sql: Some(member_sql), + } + } + + pub fn new_without_sql(agg_type: AggregationType) -> Self { + Self { + agg_type, + member_sql: None, + } + } + + pub fn agg_type(&self) -> AggregationType { + self.agg_type + } + + pub fn member_sql(&self) -> Option<&Rc> { + self.member_sql.as_ref() + } + + pub fn evaluate_sql( + &self, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + match &self.member_sql { + Some(sql) => sql.eval(visitor, node_processor, query_tools, templates), + None => Err(CubeError::internal( + "Aggregated measure without sql cannot be evaluated directly".to_string(), + )), + } + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + if let Some(sql) = &self.member_sql { + sql.extract_symbol_deps(&mut deps); + } + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + if let Some(sql) = &self.member_sql { + sql.extract_symbol_deps_with_path(&mut deps); + } + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(Self { + agg_type: self.agg_type, + member_sql: self + .member_sql + .as_ref() + .map(|sql| sql.apply_recursive(f)) + .transpose()?, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(self.member_sql.iter()) + } + + pub fn is_owned_by_cube(&self) -> bool { + self.member_sql + .as_ref() + .is_some_and(|sql| sql.is_owned_by_cube()) + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/calculated.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/calculated.rs new file mode 100644 index 0000000000000..2db26075aff48 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/calculated.rs @@ -0,0 +1,120 @@ +use super::super::super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CalculatedMeasureType { + Number, + String, + Time, + Boolean, +} + +impl CalculatedMeasureType { + pub fn as_str(&self) -> &'static str { + match self { + Self::Number => "number", + Self::String => "string", + Self::Time => "time", + Self::Boolean => "boolean", + } + } + + pub fn from_str(s: &str) -> Option { + match s { + "number" => Some(Self::Number), + "string" => Some(Self::String), + "time" => Some(Self::Time), + "boolean" => Some(Self::Boolean), + _ => None, + } + } +} + +#[derive(Clone)] +pub struct CalculatedMeasure { + calc_type: CalculatedMeasureType, + member_sql: Option>, +} + +impl CalculatedMeasure { + pub fn new(calc_type: CalculatedMeasureType, member_sql: Rc) -> Self { + Self { + calc_type, + member_sql: Some(member_sql), + } + } + + pub fn new_without_sql(calc_type: CalculatedMeasureType) -> Self { + Self { + calc_type, + member_sql: None, + } + } + + pub fn calc_type(&self) -> CalculatedMeasureType { + self.calc_type + } + + pub fn member_sql(&self) -> Option<&Rc> { + self.member_sql.as_ref() + } + + pub fn evaluate_sql( + &self, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + match &self.member_sql { + Some(sql) => sql.eval(visitor, node_processor, query_tools, templates), + None => Err(CubeError::internal( + "Calculated measure without sql cannot be evaluated directly".to_string(), + )), + } + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + if let Some(sql) = &self.member_sql { + sql.extract_symbol_deps(&mut deps); + } + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + if let Some(sql) = &self.member_sql { + sql.extract_symbol_deps_with_path(&mut deps); + } + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(Self { + calc_type: self.calc_type, + member_sql: self + .member_sql + .as_ref() + .map(|sql| sql.apply_recursive(f)) + .transpose()?, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + Box::new(self.member_sql.iter()) + } + + pub fn is_owned_by_cube(&self) -> bool { + self.member_sql + .as_ref() + .is_some_and(|sql| sql.is_owned_by_cube()) + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/count.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/count.rs new file mode 100644 index 0000000000000..85052222b57cf --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/count.rs @@ -0,0 +1,115 @@ +use super::super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +#[derive(Clone)] +pub enum CountSql { + Auto(Vec>), + Explicit(Rc), +} + +#[derive(Clone)] +pub struct CountMeasure { + sql: CountSql, +} + +impl CountMeasure { + pub fn new(sql: CountSql) -> Self { + Self { sql } + } + + pub fn sql(&self) -> &CountSql { + &self.sql + } + + pub fn evaluate_sql( + &self, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + match &self.sql { + CountSql::Explicit(sql) => sql.eval(visitor, node_processor, query_tools, templates), + CountSql::Auto(pk_sqls) => { + if pk_sqls.len() > 1 { + let pk_strings = pk_sqls + .iter() + .map(|pk| -> Result<_, CubeError> { + let res = pk.eval( + visitor, + node_processor.clone(), + query_tools.clone(), + templates, + )?; + templates.cast_to_string(&res) + }) + .collect::, _>>()?; + templates.concat_strings(&pk_strings) + } else if pk_sqls.len() == 1 { + let pk_sql = pk_sqls.first().unwrap(); + pk_sql.eval(visitor, node_processor, query_tools, templates) + } else { + Ok("*".to_string()) + } + } + } + } + + pub fn get_dependencies(&self) -> Vec> { + let mut deps = vec![]; + match &self.sql { + CountSql::Explicit(sql) => sql.extract_symbol_deps(&mut deps), + CountSql::Auto(pk_sqls) => { + for pk in pk_sqls { + pk.extract_symbol_deps(&mut deps); + } + } + } + deps + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + let mut deps = vec![]; + match &self.sql { + CountSql::Explicit(sql) => sql.extract_symbol_deps_with_path(&mut deps), + CountSql::Auto(pk_sqls) => { + for pk in pk_sqls { + pk.extract_symbol_deps_with_path(&mut deps); + } + } + } + deps + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + let sql = match &self.sql { + CountSql::Explicit(sql) => CountSql::Explicit(sql.apply_recursive(f)?), + CountSql::Auto(pk_sqls) => { + let new_pks = pk_sqls + .iter() + .map(|pk| pk.apply_recursive(f)) + .collect::, _>>()?; + CountSql::Auto(new_pks) + } + }; + Ok(Self { sql }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + match &self.sql { + CountSql::Explicit(sql) => Box::new(std::iter::once(sql)), + CountSql::Auto(pk_sqls) => Box::new(pk_sqls.iter()), + } + } + + pub fn is_owned_by_cube(&self) -> bool { + matches!(self.sql, CountSql::Auto(_)) + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/mod.rs new file mode 100644 index 0000000000000..91903b045313c --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_kinds/mod.rs @@ -0,0 +1,248 @@ +mod aggregated; +mod calculated; +mod count; + +pub use aggregated::*; +pub use calculated::*; +pub use count::*; + +use super::common::AggregationType; +use super::MemberSymbol; +use crate::planner::query_tools::QueryTools; +use crate::planner::sql_evaluator::{sql_nodes::SqlNode, SqlCall, SqlEvaluatorVisitor}; +use crate::planner::sql_templates::PlanSqlTemplates; +use cubenativeutils::CubeError; +use std::rc::Rc; + +pub enum AggregateWrap<'a> { + PassThrough, + Function(&'a str), + CountDistinct, + CountDistinctApprox, +} + +#[derive(Clone)] +pub enum MeasureKind { + Count(CountMeasure), + Aggregated(AggregatedMeasure), + Calculated(CalculatedMeasure), + Rank, +} + +impl MeasureKind { + pub fn from_type_str( + measure_type: &str, + member_sql: Option>, + pk_sqls: Vec>, + ) -> Result { + if measure_type == "count" { + Ok(match member_sql { + Some(sql) => Self::Count(CountMeasure::new(CountSql::Explicit(sql))), + None => Self::Count(CountMeasure::new(CountSql::Auto(pk_sqls))), + }) + } else if measure_type == "rank" { + Ok(Self::Rank) + } else if let Some(calc_type) = CalculatedMeasureType::from_str(measure_type) { + Ok(if let Some(sql) = member_sql { + Self::Calculated(CalculatedMeasure::new(calc_type, sql)) + } else { + Self::Calculated(CalculatedMeasure::new_without_sql(calc_type)) + }) + } else if let Ok(agg_type) = AggregationType::from_str(measure_type) { + Ok(if let Some(sql) = member_sql { + Self::Aggregated(AggregatedMeasure::new(agg_type, sql)) + } else { + Self::Aggregated(AggregatedMeasure::new_without_sql(agg_type)) + }) + } else { + Err(CubeError::user(format!( + "Unknown measure type: '{}'", + measure_type + ))) + } + } + + pub fn evaluate_sql( + &self, + full_name: &str, + visitor: &SqlEvaluatorVisitor, + node_processor: Rc, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result { + match self { + Self::Count(c) => c.evaluate_sql(visitor, node_processor, query_tools, templates), + Self::Aggregated(a) => a.evaluate_sql(visitor, node_processor, query_tools, templates), + Self::Calculated(c) => c.evaluate_sql(visitor, node_processor, query_tools, templates), + Self::Rank => Err(CubeError::internal(format!( + "Rank measure doesn't support direct evaluation for {}", + full_name + ))), + } + } + + pub fn get_dependencies(&self) -> Vec> { + match self { + Self::Count(c) => c.get_dependencies(), + Self::Aggregated(a) => a.get_dependencies(), + Self::Calculated(c) => c.get_dependencies(), + Self::Rank => vec![], + } + } + + pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { + match self { + Self::Count(c) => c.get_dependencies_with_path(), + Self::Aggregated(a) => a.get_dependencies_with_path(), + Self::Calculated(c) => c.get_dependencies_with_path(), + Self::Rank => vec![], + } + } + + pub fn apply_to_deps) -> Result, CubeError>>( + &self, + f: &F, + ) -> Result { + Ok(match self { + Self::Count(c) => Self::Count(c.apply_to_deps(f)?), + Self::Aggregated(a) => Self::Aggregated(a.apply_to_deps(f)?), + Self::Calculated(c) => Self::Calculated(c.apply_to_deps(f)?), + Self::Rank => Self::Rank, + }) + } + + pub fn iter_sql_calls(&self) -> Box> + '_> { + match self { + Self::Count(c) => c.iter_sql_calls(), + Self::Aggregated(a) => a.iter_sql_calls(), + Self::Calculated(c) => c.iter_sql_calls(), + Self::Rank => Box::new(std::iter::empty()), + } + } + + pub fn is_owned_by_cube(&self) -> bool { + match self { + Self::Count(c) => c.is_owned_by_cube(), + Self::Aggregated(a) => a.is_owned_by_cube(), + Self::Calculated(c) => c.is_owned_by_cube(), + Self::Rank => false, + } + } + + pub fn is_calculated(&self) -> bool { + matches!(self, Self::Calculated(_)) + } + + pub fn is_additive(&self) -> bool { + match self { + Self::Count(_) => true, + Self::Aggregated(a) => a.agg_type().is_additive(), + _ => false, + } + } + + pub fn measure_type_str(&self) -> &str { + match self { + Self::Count(_) => "count", + Self::Aggregated(a) => a.agg_type().as_str(), + Self::Calculated(c) => c.calc_type().as_str(), + Self::Rank => "rank", + } + } + + pub fn can_replace_type_with(&self, new_type: &str) -> bool { + match self { + Self::Aggregated(a) => { + let target_ok = matches!( + new_type, + "sum" | "avg" | "min" | "max" | "count_distinct" | "count_distinct_approx" + ); + match a.agg_type() { + AggregationType::Sum + | AggregationType::Avg + | AggregationType::Min + | AggregationType::Max => target_ok, + AggregationType::CountDistinct | AggregationType::CountDistinctApprox => { + matches!(new_type, "count_distinct" | "count_distinct_approx") + } + _ => false, + } + } + _ => false, + } + } + + pub fn supports_additional_filters(&self) -> bool { + match self { + Self::Count(_) => true, + Self::Aggregated(a) => matches!( + a.agg_type(), + AggregationType::Sum + | AggregationType::Avg + | AggregationType::Min + | AggregationType::Max + | AggregationType::CountDistinct + | AggregationType::CountDistinctApprox + ), + _ => false, + } + } + + pub fn member_sql(&self) -> Option<&Rc> { + match self { + Self::Count(c) => match c.sql() { + CountSql::Explicit(sql) => Some(sql), + CountSql::Auto(_) => None, + }, + Self::Aggregated(a) => a.member_sql(), + Self::Calculated(c) => c.member_sql(), + Self::Rank => None, + } + } + + pub fn aggregate_wrap(&self, is_multiplied: bool) -> AggregateWrap<'_> { + match self { + Self::Calculated(_) => AggregateWrap::PassThrough, + Self::Aggregated(a) => match a.agg_type() { + AggregationType::NumberAgg => AggregateWrap::PassThrough, + AggregationType::CountDistinctApprox => AggregateWrap::CountDistinctApprox, + AggregationType::CountDistinct => AggregateWrap::CountDistinct, + AggregationType::RunningTotal => AggregateWrap::Function("sum"), + _ => AggregateWrap::Function(a.agg_type().as_str()), + }, + Self::Count(_) => { + if is_multiplied { + AggregateWrap::CountDistinct + } else { + AggregateWrap::Function("count") + } + } + Self::Rank => AggregateWrap::PassThrough, + } + } + + pub fn pre_aggregate_wrap(&self) -> AggregateWrap<'_> { + match self { + Self::Count(_) => AggregateWrap::Function("sum"), + Self::Aggregated(a) => match a.agg_type() { + AggregationType::CountDistinctApprox => AggregateWrap::CountDistinctApprox, + AggregationType::Min => AggregateWrap::Function("min"), + AggregationType::Max => AggregateWrap::Function("max"), + _ => AggregateWrap::Function("sum"), + }, + _ => AggregateWrap::Function("sum"), + } + } + + pub fn with_new_type(&self, new_type: &str) -> Result { + let member_sql = self.member_sql().cloned(); + let pk_sqls = match self { + Self::Count(c) => match c.sql() { + CountSql::Auto(pks) => pks.clone(), + _ => vec![], + }, + _ => vec![], + }; + Self::from_type_str(new_type, member_sql, pk_sqls) + } +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_symbol.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_symbol.rs index 2a63037de71f6..7010d4b6bb0da 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_symbol.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/measure_symbol.rs @@ -1,4 +1,5 @@ -use super::common::Case; +use super::common::{AggregationType, Case}; +use super::measure_kinds::{CalculatedMeasure, CalculatedMeasureType, MeasureKind}; use super::SymbolPath; use super::{MemberSymbol, SymbolFactory}; use crate::cube_bridge::evaluator::CubeEvaluator; @@ -72,8 +73,7 @@ pub struct MeasureSymbol { cube: Rc, name: String, alias: String, - owned_by_cube: bool, - measure_type: String, + kind: MeasureKind, rolling_window: Option, is_multi_stage: bool, is_reference: bool, @@ -86,8 +86,6 @@ pub struct MeasureSymbol { reduce_by: Option>>, add_group_by: Option>>, group_by: Option>>, - member_sql: Option>, - pk_sqls: Vec>, is_splitted_source: bool, } @@ -96,13 +94,12 @@ impl MeasureSymbol { cube: Rc, name: String, alias: String, - member_sql: Option>, is_reference: bool, is_view: bool, - owned_by_cube: bool, case: Option, - pk_sqls: Vec>, - definition: Rc, + kind: MeasureKind, + rolling_window: Option, + is_multi_stage: bool, measure_filters: Vec>, measure_drill_filters: Vec>, time_shift: Option, @@ -111,20 +108,14 @@ impl MeasureSymbol { add_group_by: Option>>, group_by: Option>>, ) -> Rc { - let measure_type = definition.static_data().measure_type.clone(); - let rolling_window = definition.static_data().rolling_window.clone(); - let is_multi_stage = definition.static_data().multi_stage.unwrap_or(false); Rc::new(Self { cube, name, alias, - member_sql, is_reference, is_view, case, - pk_sqls, - owned_by_cube, - measure_type, + kind, rolling_window, measure_filters, measure_drill_filters, @@ -140,17 +131,25 @@ impl MeasureSymbol { pub fn new_unrolling(&self) -> Rc { if self.is_rolling_window() { - let measure_type = if self.is_multi_stage { - format!("number") + let kind = if self.is_multi_stage { + if let Some(sql) = self.kind.member_sql() { + MeasureKind::Calculated(CalculatedMeasure::new( + CalculatedMeasureType::Number, + sql.clone(), + )) + } else { + MeasureKind::Calculated(CalculatedMeasure::new_without_sql( + CalculatedMeasureType::Number, + )) + } } else { - self.measure_type.clone() + self.kind.clone() }; Rc::new(Self { cube: self.cube.clone(), name: self.name.clone(), alias: self.alias.clone(), - owned_by_cube: self.owned_by_cube, - measure_type, + kind, rolling_window: None, is_multi_stage: false, is_reference: false, @@ -163,8 +162,6 @@ impl MeasureSymbol { reduce_by: self.reduce_by.clone(), add_group_by: self.add_group_by.clone(), group_by: self.group_by.clone(), - member_sql: self.member_sql.clone(), - pk_sqls: self.pk_sqls.clone(), is_splitted_source: self.is_splitted_source, }) } else { @@ -177,64 +174,36 @@ impl MeasureSymbol { new_measure_type: Option, add_filters: Vec>, ) -> Result, CubeError> { - let result_measure_type = if let Some(new_measure_type) = new_measure_type { - match self.measure_type.as_str() { - "sum" | "avg" | "min" | "max" => match new_measure_type.as_str() { - "sum" | "avg" | "min" | "max" | "count_distinct" | "count_distinct_approx" => {} - _ => { - return Err(CubeError::user(format!( - "Unsupported measure type replacement for {}: {} => {}", - self.name, self.measure_type, new_measure_type - ))) - } - }, - "count_distinct" | "count_distinct_approx" => match new_measure_type.as_str() { - "count_distinct" | "count_distinct_approx" => {} - _ => { - return Err(CubeError::user(format!( - "Unsupported measure type replacement for {}: {} => {}", - self.name, self.measure_type, new_measure_type - ))) - } - }, - - _ => { - return Err(CubeError::user(format!( - "Unsupported measure type replacement for {}: {} => {}", - self.name, self.measure_type, new_measure_type - ))) - } + let result_kind = if let Some(new_measure_type) = new_measure_type { + if !self.kind.can_replace_type_with(&new_measure_type) { + return Err(CubeError::user(format!( + "Unsupported measure type replacement for {}: {} => {}", + self.name, + self.kind.measure_type_str(), + new_measure_type + ))); } - new_measure_type + self.kind.with_new_type(&new_measure_type)? } else { - self.measure_type.clone() + self.kind.clone() }; let mut measure_filters = self.measure_filters.clone(); if !add_filters.is_empty() { - match result_measure_type.as_str() { - "sum" - | "avg" - | "min" - | "max" - | "count" - | "count_distinct" - | "count_distinct_approx" => {} - _ => { - return Err(CubeError::user(format!( - "Unsupported additional filters for measure {} type {}", - self.name, result_measure_type - ))) - } + if !result_kind.supports_additional_filters() { + return Err(CubeError::user(format!( + "Unsupported additional filters for measure {} type {}", + self.name, + result_kind.measure_type_str() + ))); } - measure_filters.extend(add_filters.into_iter()); + measure_filters.extend(add_filters); } Ok(Rc::new(Self { cube: self.cube.clone(), name: self.name.clone(), alias: self.alias.clone(), - owned_by_cube: self.owned_by_cube, - measure_type: result_measure_type, + kind: result_kind, rolling_window: self.rolling_window.clone(), is_multi_stage: self.is_multi_stage, is_reference: self.is_reference, @@ -247,8 +216,6 @@ impl MeasureSymbol { reduce_by: self.reduce_by.clone(), add_group_by: self.add_group_by.clone(), group_by: self.group_by.clone(), - member_sql: self.member_sql.clone(), - pk_sqls: self.pk_sqls.clone(), is_splitted_source: self.is_splitted_source, })) } @@ -271,23 +238,12 @@ impl MeasureSymbol { self.is_splitted_source } - pub fn pk_sqls(&self) -> &Vec> { - &self.pk_sqls - } - pub fn time_shift(&self) -> &Option { &self.time_shift } pub fn is_calculated(&self) -> bool { - Self::is_calculated_type(&self.measure_type) - } - - pub fn is_calculated_type(measure_type: &str) -> bool { - match measure_type { - "number" | "string" | "time" | "boolean" => true, - _ => false, - } + self.kind.is_calculated() } pub fn case(&self) -> Option<&Case> { @@ -298,10 +254,7 @@ impl MeasureSymbol { if self.is_multi_stage() { false } else { - match self.measure_type().as_str() { - "sum" | "count" | "countDistinctApprox" | "min" | "max" => true, - _ => false, - } + self.kind.is_additive() } } @@ -312,19 +265,13 @@ impl MeasureSymbol { query_tools: Rc, templates: &PlanSqlTemplates, ) -> Result { - if let Some(member_sql) = &self.member_sql { - let sql = member_sql.eval(visitor, node_processor, query_tools, templates)?; - Ok(sql) - } else { - Err(CubeError::internal(format!( - "Measure {} hasn't sql evaluator", - self.full_name() - ))) - } - } - - pub fn has_sql(&self) -> bool { - self.member_sql.is_some() + self.kind.evaluate_sql( + &self.full_name(), + visitor, + node_processor, + query_tools, + templates, + ) } pub fn apply_to_deps) -> Result, CubeError>>( @@ -332,13 +279,7 @@ impl MeasureSymbol { f: &F, ) -> Result, CubeError> { let mut result = self.clone(); - if let Some(member_sql) = &self.member_sql { - result.member_sql = Some(member_sql.apply_recursive(f)?); - } - - for sql in result.pk_sqls.iter_mut() { - *sql = sql.apply_recursive(f)? - } + result.kind = result.kind.apply_to_deps(f)?; for sql in result.measure_filters.iter_mut() { *sql = sql.apply_recursive(f)? @@ -363,20 +304,14 @@ impl MeasureSymbol { //FIXME We don't include filters and order_by here for backward compatibility // because BaseQuery doesn't validate these SQL calls let result = self - .member_sql - .iter() + .kind + .iter_sql_calls() .chain(self.case.iter().flat_map(|case| case.iter_sql_calls())); Box::new(result) } pub fn get_dependencies(&self) -> Vec> { - let mut deps = vec![]; - if let Some(member_sql) = &self.member_sql { - member_sql.extract_symbol_deps(&mut deps); - } - for pk in self.pk_sqls.iter() { - pk.extract_symbol_deps(&mut deps); - } + let mut deps = self.kind.get_dependencies(); for filter in self.measure_filters.iter() { filter.extract_symbol_deps(&mut deps); } @@ -393,13 +328,7 @@ impl MeasureSymbol { } pub fn get_dependencies_with_path(&self) -> Vec<(Rc, Vec)> { - let mut deps = vec![]; - if let Some(member_sql) = &self.member_sql { - member_sql.extract_symbol_deps_with_path(&mut deps); - } - for pk in self.pk_sqls.iter() { - pk.extract_symbol_deps_with_path(&mut deps); - } + let mut deps = self.kind.get_dependencies_with_path(); for filter in self.measure_filters.iter() { filter.extract_symbol_deps_with_path(&mut deps); } @@ -416,17 +345,28 @@ impl MeasureSymbol { } pub fn can_used_as_addictive_in_multplied(&self) -> bool { - if &self.measure_type == "countDistinct" || &self.measure_type == "countDistinctApprox" { - true - } else if &self.measure_type == "count" && self.member_sql.is_none() { - true - } else { - false + match &self.kind { + MeasureKind::Aggregated(agg) => agg.agg_type().is_distinct(), + MeasureKind::Count(count) => count.is_owned_by_cube(), + _ => false, } } pub fn owned_by_cube(&self) -> bool { - self.owned_by_cube + if self.is_multi_stage { + return false; + } + let mut owned = self.kind.is_owned_by_cube(); + for sql in &self.measure_filters { + owned |= sql.is_owned_by_cube(); + } + for sql in &self.measure_drill_filters { + owned |= sql.is_owned_by_cube(); + } + if let Some(case) = &self.case { + owned |= case.is_owned_by_cube(); + } + owned } pub fn is_reference(&self) -> bool { @@ -448,8 +388,12 @@ impl MeasureSymbol { deps.first().cloned() } - pub fn measure_type(&self) -> &String { - &self.measure_type + pub fn measure_type(&self) -> &str { + self.kind.measure_type_str() + } + + pub fn kind(&self) -> &MeasureKind { + &self.kind } pub fn rolling_window(&self) -> &Option { @@ -461,7 +405,7 @@ impl MeasureSymbol { } pub fn is_running_total(&self) -> bool { - self.measure_type() == "runningTotal" + matches!(&self.kind, MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::RunningTotal) } pub fn is_cumulative(&self) -> bool { @@ -596,11 +540,7 @@ impl SymbolFactory for MeasureSymbolFactory { None }; - let is_sql_is_direct_ref = if let Some(sql) = &sql { - sql.is_direct_reference() - } else { - false - }; + let is_sql_is_direct_ref = sql.as_ref().is_some_and(|s| s.is_direct_reference()); let time_shifts = if let Some(time_shift_references) = &definition.static_data().time_shift_references @@ -723,20 +663,17 @@ impl SymbolFactory for MeasureSymbolFactory { None }; - let measure_type = &definition.static_data().measure_type; - let is_calculated = MeasureSymbol::is_calculated_type(&measure_type) - && !definition.static_data().multi_stage.unwrap_or(false); - + let measure_type_str = &definition.static_data().measure_type; + let rolling_window = definition.static_data().rolling_window.clone(); let is_multi_stage = definition.static_data().multi_stage.unwrap_or(false); + + let kind = MeasureKind::from_type_str(measure_type_str, sql, pk_sqls)?; + let is_calculated = kind.is_calculated() && !is_multi_stage; + let owned_by_cube = if is_multi_stage { false - } else if measure_type == "count" && sql.is_none() { - true } else { - let mut owned = false; - if let Some(sql) = &sql { - owned |= sql.is_owned_by_cube(); - } + let mut owned = kind.is_owned_by_cube(); for sql in &measure_filters { owned |= sql.is_owned_by_cube(); } @@ -780,13 +717,12 @@ impl SymbolFactory for MeasureSymbolFactory { cube_symbol, path.symbol_name().clone(), alias, - sql, is_reference, is_view, - owned_by_cube, case, - pk_sqls, - definition, + kind, + rolling_window, + is_multi_stage, measure_filters, measure_drill_filters, time_shifts, @@ -806,7 +742,7 @@ impl crate::utils::debug::DebugSql for MeasureSymbol { } // Get base SQL - let base_sql = if let Some(sql) = &self.member_sql { + let base_sql = if let Some(sql) = self.kind.member_sql() { sql.debug_sql(expand_deps) } else { "".to_string() diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/mod.rs index 84d02f0d1997c..c7ec09111a11b 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/mod.rs @@ -1,6 +1,8 @@ mod common; mod cube_symbol; +pub mod dimension_kinds; mod dimension_symbol; +pub mod measure_kinds; mod measure_symbol; mod member_expression_symbol; mod member_symbol; @@ -11,7 +13,12 @@ pub use common::*; pub use cube_symbol::{ CubeNameSymbol, CubeNameSymbolFactory, CubeTableSymbol, CubeTableSymbolFactory, }; +pub use dimension_kinds::DimensionKind; pub use dimension_symbol::*; +pub use measure_kinds::{ + AggregateWrap, AggregatedMeasure, CalculatedMeasure, CalculatedMeasureType, CountMeasure, + CountSql, MeasureKind, +}; pub use measure_symbol::{ DimensionTimeShift, MeasureSymbol, MeasureSymbolFactory, MeasureTimeShifts, }; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/time_dimension_symbol.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/time_dimension_symbol.rs index a36c6726134d8..2e9347065a8ab 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/time_dimension_symbol.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/symbols/time_dimension_symbol.rs @@ -118,7 +118,7 @@ impl TimeDimensionSymbol { .into_iter() .map(|s| match s.as_ref() { MemberSymbol::Dimension(dimension_symbol) => { - if dimension_symbol.dimension_type() == "time" { + if dimension_symbol.is_time() { let result = Self::new( s.clone(), self.granularity.clone(), diff --git a/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/calc_groups.yaml b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/calc_groups.yaml new file mode 100644 index 0000000000000..7bbcf9bddcabe --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/calc_groups.yaml @@ -0,0 +1,38 @@ +cubes: + - name: orders + sql: "SELECT * FROM orders" + dimensions: + - name: id + type: number + sql: id + primary_key: true + - name: date + type: time + sql: date + - name: currency + type: string + sql: currency + measures: + - name: amount_usd + type: sum + sql: amount_usd + - name: amount_eur + type: sum + sql: amount_eur + - name: amount_gbp + type: sum + sql: amount_gbp + - name: amount_in_currency + type: number + multi_stage: true + case: + switch: "{CUBE.currency}" + when: + - value: "USD" + sql: "{CUBE.amount_usd}" + - value: "EUR" + sql: "{CUBE.amount_eur}" + - value: "GBP" + sql: "{CUBE.amount_gbp}" + else: + sql: "{CUBE.amount_usd}" diff --git a/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/dimension_kind_tests.yaml b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/dimension_kind_tests.yaml new file mode 100644 index 0000000000000..ba03beadd6667 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/dimension_kind_tests.yaml @@ -0,0 +1,68 @@ +cubes: + - name: aux + sql: "SELECT * FROM aux" + dimensions: + - name: id + type: number + sql: id + primary_key: true + - name: min_date + type: time + sql: "MIN(created_at)" + + - name: test_dims + sql: "SELECT * FROM test_dims" + joins: + - name: aux + relationship: many_to_one + sql: "{CUBE}.aux_id = {aux}.id" + dimensions: + - name: id + type: number + sql: id + primary_key: true + - name: name + type: string + sql: "{CUBE}.name" + - name: amount + type: number + sql: "{CUBE}.amount" + - name: created_at + type: time + sql: created_at + - name: is_active + type: boolean + sql: "{CUBE}.is_active" + - name: location + type: geo + latitude: "{CUBE}.lat" + longitude: "{CUBE}.lng" + - name: currency + type: switch + sql: "{CUBE}.currency" + values: + - USD + - EUR + - GBP + - name: calc_group + type: switch + values: + - option_a + - option_b + - name: status_label + type: string + case: + when: + - sql: "{CUBE}.status = 1" + label: Active + - sql: "{CUBE}.status = 2" + label: Inactive + else: + label: Unknown + - name: sub_query_dim + type: time + sql: "{aux.min_date}" + sub_query: true + measures: + - name: cnt + type: count diff --git a/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/measure_kind_tests.yaml b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/measure_kind_tests.yaml new file mode 100644 index 0000000000000..2fba7d430b418 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/common/measure_kind_tests.yaml @@ -0,0 +1,57 @@ +cubes: + - name: test_measures + sql: "SELECT * FROM test_measures" + dimensions: + - name: id + type: number + sql: id + primary_key: true + - name: created_at + type: time + sql: created_at + - name: status + type: string + sql: "{CUBE}.status" + measures: + - name: total + type: sum + sql: amount + - name: average + type: avg + sql: amount + - name: minimum + type: min + sql: amount + - name: maximum + type: max + sql: amount + - name: cnt + type: count + - name: distinct_count + type: countDistinct + sql: user_id + - name: approx_count + type: countDistinctApprox + sql: user_id + - name: running + type: runningTotal + sql: amount + - name: number_agg + type: numberAgg + sql: amount + - name: calculated + type: number + sql: "{CUBE.total} / {CUBE.cnt}" + - name: rank_measure + type: rank + - name: rolling_sum + type: sum + sql: amount + rolling_window: + trailing: "7 day" + offset: start + - name: filtered_total + type: sum + sql: amount + filters: + - sql: "{CUBE}.status = 'active'" diff --git a/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/symbol_evaluator/measure_types.yaml b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/symbol_evaluator/measure_types.yaml new file mode 100644 index 0000000000000..fc2c6d14c11bb --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/test_fixtures/schemas/yaml_files/symbol_evaluator/measure_types.yaml @@ -0,0 +1,33 @@ +cubes: + - name: test_cube + sql: "SELECT 1" + dimensions: + - name: id + type: number + sql: id + primary_key: true + - name: source + type: string + sql: "{CUBE}.source" + - name: created_at + type: time + sql: created_at + - name: status + type: string + sql: "{CUBE}.status" + measures: + - name: sum_revenue + type: sum + sql: revenue + - name: string_status + type: string + sql: "{CUBE.source}" + - name: time_last_activity + type: time + sql: "{CUBE.created_at}" + - name: boolean_has_revenue + type: boolean + sql: "{CUBE.sum_revenue} > 0" + - name: number_agg_metric + type: numberAgg + sql: "{CUBE.sum_revenue} * 100" diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/common_sql_generation.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/common_sql_generation.rs index ec214f5bcbed5..b9f8fc7dcfa88 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/tests/common_sql_generation.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/common_sql_generation.rs @@ -155,3 +155,36 @@ fn test_segment_as_dimension_in_pre_aggregation_query() { insta::assert_snapshot!(sql); } + +#[test] +fn test_measure_switch_cross_join() { + let schema = MockSchema::from_yaml_file("common/calc_groups.yaml"); + let test_context = TestContext::new(schema).unwrap(); + + let query_yaml = indoc! {" + dimensions: + - orders.currency + measures: + - orders.amount_usd + - orders.amount_in_currency + time_dimensions: + - dimension: orders.date + granularity: year + dateRange: + - \"2024-01-01\" + - \"2026-01-01\" + "}; + + let sql = test_context + .build_sql(query_yaml) + .expect("Should generate SQL for case-switch measure"); + + // amount_in_currency is type "number" — a calculated measure. + // It must NOT be wrapped in an aggregation function like number(...). + assert!( + !sql.contains("number("), + "Calculated measure must not be wrapped in number() aggregation" + ); + + insta::assert_snapshot!(sql); +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/compilation.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/compilation.rs index dadad8e0ad8b1..c57f9f1a09bb6 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/compilation.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/compilation.rs @@ -1,5 +1,8 @@ //! Tests for Compiler member evaluation +use crate::planner::sql_evaluator::symbols::dimension_kinds::DimensionKind; +use crate::planner::sql_evaluator::symbols::DimensionType; +use crate::planner::sql_evaluator::{AggregationType, CalculatedMeasureType, MeasureKind}; use crate::test_fixtures::cube_bridge::MockSchema; use crate::test_fixtures::schemas::TestCompiler; use crate::test_fixtures::test_utils::TestContext; @@ -21,7 +24,9 @@ fn test_add_dimension_evaluator_number_dimension() { assert_eq!(symbol.cube_name(), "visitors"); assert_eq!(symbol.name(), "id"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "number"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Number) + ); } #[test] @@ -41,7 +46,9 @@ fn test_add_dimension_evaluator_string_dimension() { assert_eq!(symbol.cube_name(), "visitors"); assert_eq!(symbol.name(), "source"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "string"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::String) + ); } #[test] @@ -95,16 +102,16 @@ fn test_add_dimension_evaluator_multiple_dimensions() { .unwrap(); assert_eq!(id_symbol.full_name(), "visitors.id"); - assert_eq!(id_symbol.as_dimension().unwrap().dimension_type(), "number"); + assert!( + matches!(id_symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Number) + ); assert_eq!(source_symbol.full_name(), "visitors.source"); - assert_eq!( - source_symbol.as_dimension().unwrap().dimension_type(), - "string" + assert!( + matches!(source_symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::String) ); assert_eq!(created_at_symbol.full_name(), "visitors.created_at"); - assert_eq!( - created_at_symbol.as_dimension().unwrap().dimension_type(), - "time" + assert!( + matches!(created_at_symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Time) ); assert_eq!(id_symbol.get_dependencies().len(), 0); assert_eq!(source_symbol.get_dependencies().len(), 0); @@ -128,7 +135,10 @@ fn test_add_measure_evaluator_count_measure() { assert_eq!(symbol.cube_name(), "visitor_checkins"); assert_eq!(symbol.name(), "count"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_measure().unwrap().measure_type(), "count"); + assert!(matches!( + symbol.as_measure().unwrap().kind(), + MeasureKind::Count(_) + )); } #[test] @@ -148,7 +158,10 @@ fn test_add_measure_evaluator_sum_measure() { assert_eq!(symbol.cube_name(), "visitors"); assert_eq!(symbol.name(), "total_revenue"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_measure().unwrap().measure_type(), "sum"); + assert!(matches!( + symbol.as_measure().unwrap().kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); } #[test] @@ -198,9 +211,15 @@ fn test_add_measure_evaluator_multiple_measures() { .unwrap(); assert_eq!(count_symbol.full_name(), "visitor_checkins.count"); - assert_eq!(count_symbol.as_measure().unwrap().measure_type(), "count"); + assert!(matches!( + count_symbol.as_measure().unwrap().kind(), + MeasureKind::Count(_) + )); assert_eq!(revenue_symbol.full_name(), "visitors.total_revenue"); - assert_eq!(revenue_symbol.as_measure().unwrap().measure_type(), "sum"); + assert!(matches!( + revenue_symbol.as_measure().unwrap().kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); assert_eq!(count_symbol.get_dependencies().len(), 0); assert_eq!(revenue_symbol.get_dependencies().len(), 0); } @@ -222,7 +241,9 @@ fn test_add_auto_resolved_member_evaluator_dimension() { assert_eq!(symbol.cube_name(), "visitors"); assert_eq!(symbol.name(), "source"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "string"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::String) + ); } #[test] @@ -242,7 +263,10 @@ fn test_add_auto_resolved_member_evaluator_measure() { assert_eq!(symbol.cube_name(), "visitors"); assert_eq!(symbol.name(), "total_revenue"); assert_eq!(symbol.get_dependencies().len(), 0); - assert_eq!(symbol.as_measure().unwrap().measure_type(), "sum"); + assert!(matches!( + symbol.as_measure().unwrap().kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); } #[test] @@ -295,7 +319,9 @@ fn test_dimension_with_cube_table_dependency() { assert!(symbol.is_dimension()); assert_eq!(symbol.full_name(), "visitors.visitor_id"); assert_eq!(symbol.cube_name(), "visitors"); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "number"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Number) + ); let dependencies = symbol.get_dependencies(); assert_eq!(dependencies.len(), 1, "Should have 1 dependency on CUBE"); @@ -320,7 +346,9 @@ fn test_dimension_with_member_dependency_no_prefix() { assert!(symbol.is_dimension()); assert_eq!(symbol.full_name(), "visitors.visitor_id_twice"); assert_eq!(symbol.cube_name(), "visitors"); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "number"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Number) + ); let dependencies = symbol.get_dependencies(); assert_eq!( @@ -349,7 +377,9 @@ fn test_dimension_with_mixed_dependencies() { assert!(symbol.is_dimension()); assert_eq!(symbol.full_name(), "visitors.source_concat_id"); assert_eq!(symbol.cube_name(), "visitors"); - assert_eq!(symbol.as_dimension().unwrap().dimension_type(), "string"); + assert!( + matches!(symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::String) + ); let dependencies = symbol.get_dependencies(); assert_eq!( @@ -388,7 +418,10 @@ fn test_measure_with_cube_table_dependency() { assert!(symbol.is_measure()); assert_eq!(symbol.full_name(), "visitors.revenue"); assert_eq!(symbol.cube_name(), "visitors"); - assert_eq!(symbol.as_measure().unwrap().measure_type(), "sum"); + assert!(matches!( + symbol.as_measure().unwrap().kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); let dependencies = symbol.get_dependencies(); assert_eq!(dependencies.len(), 1, "Should have 1 dependency on CUBE"); @@ -413,7 +446,10 @@ fn test_measure_with_explicit_cube_and_member_dependencies() { assert!(symbol.is_measure()); assert_eq!(symbol.full_name(), "visitors.total_revenue_per_count"); assert_eq!(symbol.cube_name(), "visitors"); - assert_eq!(symbol.as_measure().unwrap().measure_type(), "number"); + assert!(matches!( + symbol.as_measure().unwrap().kind(), + MeasureKind::Calculated(c) if c.calc_type() == CalculatedMeasureType::Number + )); let dependencies = symbol.get_dependencies(); assert_eq!(dependencies.len(), 2, "Should have 2 measure dependencies"); @@ -667,9 +703,8 @@ fn test_time_dimension_with_granularity_compilation() { "visitors.created_at", "Base symbol should be visitors.created_at" ); - assert_eq!( - base_symbol.as_dimension().unwrap().dimension_type(), - "time", + assert!( + matches!(base_symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Time), "Base dimension should be time type" ); } @@ -721,9 +756,8 @@ fn test_sql_deps_validation() { "visitors.created_at", "Base symbol should be visitors.created_at" ); - assert_eq!( - base_symbol.as_dimension().unwrap().dimension_type(), - "time", + assert!( + matches!(base_symbol.as_dimension().unwrap().kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Time), "Base dimension should be time type" ); } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/symbol_evaluator.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/symbol_evaluator.rs index 5e27038299607..d1e82887c11a6 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/symbol_evaluator.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/cube_evaluator/symbol_evaluator.rs @@ -136,3 +136,49 @@ fn composite_symbols() { r#"sum("test_cube".revenue) + avg("test_cube".revenue)/min("test_cube".revenue) - min("test_cube".revenue)"# ); } + +#[test] +fn string_measure() { + let schema = MockSchema::from_yaml_file("symbol_evaluator/measure_types.yaml"); + let context = TestContext::new(schema).unwrap(); + + let symbol = context.create_measure("test_cube.string_status").unwrap(); + let sql = context.evaluate_symbol(&symbol).unwrap(); + assert_eq!(sql, r#""test_cube".source"#); +} + +#[test] +fn time_measure() { + let schema = MockSchema::from_yaml_file("symbol_evaluator/measure_types.yaml"); + let context = TestContext::new(schema).unwrap(); + + let symbol = context + .create_measure("test_cube.time_last_activity") + .unwrap(); + let sql = context.evaluate_symbol(&symbol).unwrap(); + assert_eq!(sql, r#""test_cube".created_at"#); +} + +#[test] +fn boolean_measure() { + let schema = MockSchema::from_yaml_file("symbol_evaluator/measure_types.yaml"); + let context = TestContext::new(schema).unwrap(); + + let symbol = context + .create_measure("test_cube.boolean_has_revenue") + .unwrap(); + let sql = context.evaluate_symbol(&symbol).unwrap(); + assert_eq!(sql, r#"sum("test_cube".revenue) > 0"#); +} + +#[test] +fn number_agg_measure() { + let schema = MockSchema::from_yaml_file("symbol_evaluator/measure_types.yaml"); + let context = TestContext::new(schema).unwrap(); + + let symbol = context + .create_measure("test_cube.number_agg_metric") + .unwrap(); + let sql = context.evaluate_symbol(&symbol).unwrap(); + assert_eq!(sql, r#"sum("test_cube".revenue) * 100"#); +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/dimension_symbol.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/dimension_symbol.rs new file mode 100644 index 0000000000000..2c798d4ce16bf --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/dimension_symbol.rs @@ -0,0 +1,156 @@ +//! Tests for DimensionSymbol: kind classification and helper methods + +use crate::planner::sql_evaluator::symbols::dimension_kinds::DimensionKind; +use crate::planner::sql_evaluator::symbols::DimensionType; +use crate::test_fixtures::cube_bridge::MockSchema; +use crate::test_fixtures::test_utils::TestContext; + +fn ctx() -> TestContext { + let schema = MockSchema::from_yaml_file("common/dimension_kind_tests.yaml"); + TestContext::new(schema).unwrap() +} + +// ─── Per-dimension property tests ─────────────────────────────────────────── + +#[test] +fn dimension_regular_string() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.name").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::String) + ); + assert!(!dim.is_time()); + assert!(!dim.is_geo()); + assert!(!dim.is_switch()); + assert!(!dim.is_case()); + assert!(!dim.is_calc_group()); + assert!(!dim.is_sub_query()); + assert!(dim.member_sql().is_some()); + assert!(!matches!(dim.kind(), DimensionKind::Geo(_))); + assert!(dim.case().is_none()); + assert!(dim.values().is_empty()); +} + +#[test] +fn dimension_regular_number() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.amount").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Number) + ); + assert!(!dim.is_time()); + assert!(!dim.is_geo()); +} + +#[test] +fn dimension_regular_time() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.created_at").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Time) + ); + assert!(dim.is_time()); + assert!(!dim.is_geo()); + assert!(!dim.is_switch()); + assert!(!dim.is_case()); +} + +#[test] +fn dimension_regular_boolean() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.is_active").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Boolean) + ); + assert!(!dim.is_time()); +} + +#[test] +fn dimension_geo() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.location").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!(matches!(dim.kind(), DimensionKind::Geo(_))); + assert!(dim.is_geo()); + assert!(!dim.is_time()); + assert!(!dim.is_switch()); + assert!(!dim.is_case()); + assert!(!dim.is_calc_group()); + let geo = match dim.kind() { + DimensionKind::Geo(g) => g, + _ => panic!("Expected Geo kind"), + }; + // latitude and longitude are guaranteed to exist in Geo kind + let _ = geo.latitude(); + let _ = geo.longitude(); + assert!(dim.member_sql().is_none()); +} + +#[test] +fn dimension_switch_with_sql() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.currency").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!(matches!(dim.kind(), DimensionKind::Switch(_))); + assert!(dim.is_switch()); + assert!(!dim.is_calc_group()); + assert!(!dim.is_time()); + assert!(!dim.is_geo()); + assert!(!dim.is_case()); + assert_eq!(dim.values(), &["USD", "EUR", "GBP"]); + assert!(dim.member_sql().is_some()); +} + +#[test] +fn dimension_calc_group() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.calc_group").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!(matches!(dim.kind(), DimensionKind::Switch(_))); + assert!(dim.is_switch()); + assert!(dim.is_calc_group()); + assert_eq!(dim.values(), &["option_a", "option_b"]); + assert!(dim.member_sql().is_none()); + assert!(!dim.owned_by_cube()); +} + +#[test] +fn dimension_case() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.status_label").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Case(c) if *c.dimension_type() == DimensionType::String) + ); + assert!(dim.is_case()); + assert!(!dim.is_time()); + assert!(!dim.is_geo()); + assert!(!dim.is_switch()); + assert!(!dim.is_calc_group()); + assert!(dim.case().is_some()); +} + +#[test] +fn dimension_sub_query() { + let ctx = ctx(); + let d = ctx.create_dimension("test_dims.sub_query_dim").unwrap(); + let dim = d.as_dimension().unwrap(); + + assert!( + matches!(dim.kind(), DimensionKind::Regular(r) if *r.dimension_type() == DimensionType::Time) + ); + assert!(dim.is_time()); + assert!(dim.is_sub_query()); +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/measure_symbol.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/measure_symbol.rs new file mode 100644 index 0000000000000..3065bfbb60ddd --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/measure_symbol.rs @@ -0,0 +1,471 @@ +//! Tests for MeasureSymbol: kind classification, new_patched, and helper methods + +use crate::planner::sql_evaluator::{AggregationType, CalculatedMeasureType, MeasureKind, SqlCall}; +use crate::test_fixtures::cube_bridge::MockSchema; +use crate::test_fixtures::test_utils::TestContext; +use std::rc::Rc; + +fn ctx() -> TestContext { + let schema = MockSchema::from_yaml_file("common/measure_kind_tests.yaml"); + TestContext::new(schema).unwrap() +} + +fn get_filter_calls(ctx: &TestContext) -> Vec> { + let symbol = ctx.create_measure("test_measures.filtered_total").unwrap(); + symbol.as_measure().unwrap().measure_filters().clone() +} + +// ─── Per-measure property tests ───────────────────────────────────────────── + +#[test] +fn measure_count_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.cnt").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!(measure.kind(), MeasureKind::Count(_))); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_rolling_window()); + assert!(!measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_sum_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.total").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_rolling_window()); + assert!(!measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_avg_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.average").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Avg + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(!measure.is_addictive()); +} + +#[test] +fn measure_min_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.minimum").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Min + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_max_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.maximum").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Max + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_count_distinct_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.distinct_count").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::CountDistinct + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(!measure.is_addictive()); +} + +#[test] +fn measure_count_distinct_approx_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.approx_count").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::CountDistinctApprox + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_running_total_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.running").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::RunningTotal + )); + assert!(!measure.is_calculated()); + assert!(measure.is_running_total()); + assert!(!measure.is_rolling_window()); + assert!(measure.is_cumulative()); + assert!(measure.is_addictive()); +} + +#[test] +fn measure_number_agg_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.number_agg").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::NumberAgg + )); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(!measure.is_addictive()); +} + +#[test] +fn measure_calculated_number_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.calculated").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Calculated(c) if c.calc_type() == CalculatedMeasureType::Number + )); + assert!(measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(!measure.is_addictive()); +} + +#[test] +fn measure_rank_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.rank_measure").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!(measure.kind(), MeasureKind::Rank)); + assert!(!measure.is_calculated()); + assert!(!measure.is_running_total()); + assert!(!measure.is_cumulative()); + assert!(!measure.is_addictive()); +} + +#[test] +fn measure_rolling_window_properties() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.rolling_sum").unwrap(); + let measure = m.as_measure().unwrap(); + + assert!(matches!( + measure.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); + assert!(measure.is_rolling_window()); + assert!(measure.is_cumulative()); + assert!(!measure.is_running_total()); +} + +// ─── new_patched: valid type replacements ─────────────────────────────────── + +#[test] +fn new_patched_sum_to_all_valid_targets() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.total").unwrap(); + let measure = m.as_measure().unwrap(); + + let cases: Vec<(&str, AggregationType)> = vec![ + ("avg", AggregationType::Avg), + ("min", AggregationType::Min), + ("max", AggregationType::Max), + ("sum", AggregationType::Sum), + ("count_distinct", AggregationType::CountDistinct), + ( + "count_distinct_approx", + AggregationType::CountDistinctApprox, + ), + ]; + for (new_type, expected_agg) in cases { + let patched = measure + .new_patched(Some(new_type.to_string()), vec![]) + .unwrap_or_else(|e| panic!("sum -> {} should succeed: {}", new_type, e)); + assert!( + matches!(patched.kind(), MeasureKind::Aggregated(a) if a.agg_type() == expected_agg), + "sum -> {}: wrong kind", + new_type + ); + assert_eq!(patched.full_name(), "test_measures.total"); + } +} + +#[test] +fn new_patched_avg_to_sum() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.average").unwrap(); + let patched = m + .as_measure() + .unwrap() + .new_patched(Some("sum".to_string()), vec![]) + .unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); +} + +#[test] +fn new_patched_count_distinct_family() { + let ctx = ctx(); + + let cd = ctx.create_measure("test_measures.distinct_count").unwrap(); + let patched = cd + .as_measure() + .unwrap() + .new_patched(Some("count_distinct_approx".to_string()), vec![]) + .unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::CountDistinctApprox + )); + + let cda = ctx.create_measure("test_measures.approx_count").unwrap(); + let patched = cda + .as_measure() + .unwrap() + .new_patched(Some("count_distinct".to_string()), vec![]) + .unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::CountDistinct + )); +} + +// ─── new_patched: invalid type replacements ───────────────────────────────── + +#[test] +fn new_patched_sum_invalid_targets() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.total").unwrap(); + let measure = m.as_measure().unwrap(); + + for invalid in ["number", "count", "runningTotal", "rank", "numberAgg"] { + assert!( + measure + .new_patched(Some(invalid.to_string()), vec![]) + .is_err(), + "sum -> {} should fail", + invalid + ); + } +} + +#[test] +fn new_patched_count_distinct_to_sum_error() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.distinct_count").unwrap(); + assert!(m + .as_measure() + .unwrap() + .new_patched(Some("sum".to_string()), vec![]) + .is_err()); +} + +#[test] +fn new_patched_non_patchable_types() { + let ctx = ctx(); + + let non_patchable = [ + "test_measures.cnt", + "test_measures.calculated", + "test_measures.rank_measure", + "test_measures.running", + ]; + for path in non_patchable { + let m = ctx.create_measure(path).unwrap(); + assert!( + m.as_measure() + .unwrap() + .new_patched(Some("sum".to_string()), vec![]) + .is_err(), + "{} -> sum should fail", + path + ); + } +} + +// ─── new_patched: no type change (None) ───────────────────────────────────── + +#[test] +fn new_patched_none_preserves_kind() { + let ctx = ctx(); + + let m = ctx.create_measure("test_measures.total").unwrap(); + let patched = m.as_measure().unwrap().new_patched(None, vec![]).unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::Sum + )); + + let m = ctx.create_measure("test_measures.cnt").unwrap(); + let patched = m.as_measure().unwrap().new_patched(None, vec![]).unwrap(); + assert!(matches!(patched.kind(), MeasureKind::Count(_))); + + let m = ctx.create_measure("test_measures.calculated").unwrap(); + let patched = m.as_measure().unwrap().new_patched(None, vec![]).unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Calculated(c) if c.calc_type() == CalculatedMeasureType::Number + )); + + let m = ctx.create_measure("test_measures.rank_measure").unwrap(); + let patched = m.as_measure().unwrap().new_patched(None, vec![]).unwrap(); + assert!(matches!(patched.kind(), MeasureKind::Rank)); +} + +// ─── new_patched: filter addition validation ──────────────────────────────── + +#[test] +fn new_patched_filters_accepted_for_aggregatable_types() { + let ctx = ctx(); + let filters = get_filter_calls(&ctx); + + let accept_filters = [ + "test_measures.total", + "test_measures.average", + "test_measures.minimum", + "test_measures.maximum", + "test_measures.cnt", + ]; + for path in accept_filters { + let m = ctx.create_measure(path).unwrap(); + let patched = m + .as_measure() + .unwrap() + .new_patched(None, filters.clone()) + .unwrap_or_else(|e| panic!("{} + filters should succeed: {}", path, e)); + assert!( + !patched.measure_filters().is_empty(), + "{}: filters should be added", + path + ); + } +} + +// Fixed: countDistinct/countDistinctApprox now correctly support filters +// via MeasureKind::supports_additional_filters() pattern matching. +#[test] +fn new_patched_count_distinct_accepts_filters() { + let ctx = ctx(); + let filters = get_filter_calls(&ctx); + + for path in ["test_measures.distinct_count", "test_measures.approx_count"] { + let m = ctx.create_measure(path).unwrap(); + assert!( + m.as_measure() + .unwrap() + .new_patched(None, filters.clone()) + .is_ok(), + "{} + filters should be Ok", + path + ); + } +} + +#[test] +fn new_patched_filters_rejected_for_non_aggregatable_types() { + let ctx = ctx(); + let filters = get_filter_calls(&ctx); + + let reject_filters = [ + "test_measures.calculated", + "test_measures.running", + "test_measures.rank_measure", + "test_measures.number_agg", + ]; + for path in reject_filters { + let m = ctx.create_measure(path).unwrap(); + assert!( + m.as_measure() + .unwrap() + .new_patched(None, filters.clone()) + .is_err(), + "{} + filters should fail", + path + ); + } +} + +// ─── new_patched: combined type change + filters ──────────────────────────── + +#[test] +fn new_patched_type_change_with_filters() { + let ctx = ctx(); + let filters = get_filter_calls(&ctx); + + let m = ctx.create_measure("test_measures.total").unwrap(); + let patched = m + .as_measure() + .unwrap() + .new_patched(Some("count_distinct".to_string()), filters) + .unwrap(); + assert!(matches!( + patched.kind(), + MeasureKind::Aggregated(a) if a.agg_type() == AggregationType::CountDistinct + )); + assert!(!patched.measure_filters().is_empty()); +} + +#[test] +fn new_patched_appends_to_existing_filters() { + let ctx = ctx(); + let m = ctx.create_measure("test_measures.filtered_total").unwrap(); + let measure = m.as_measure().unwrap(); + let original_count = measure.measure_filters().len(); + assert!(original_count > 0); + + let new_filters = get_filter_calls(&ctx); + let patched = measure.new_patched(None, new_filters.clone()).unwrap(); + assert_eq!( + patched.measure_filters().len(), + original_count + new_filters.len() + ); +} diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs b/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs index df160eab3a1ad..376256b06b7d3 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/mod.rs @@ -1,4 +1,6 @@ mod common_sql_generation; mod cube_evaluator; +mod dimension_symbol; +mod measure_symbol; mod pre_aggregation_sql_generation; mod utils; diff --git a/rust/cubesqlplanner/cubesqlplanner/src/tests/snapshots/cubesqlplanner__tests__common_sql_generation__measure_switch_cross_join.snap b/rust/cubesqlplanner/cubesqlplanner/src/tests/snapshots/cubesqlplanner__tests__common_sql_generation__measure_switch_cross_join.snap new file mode 100644 index 0000000000000..a826b2d4150e2 --- /dev/null +++ b/rust/cubesqlplanner/cubesqlplanner/src/tests/snapshots/cubesqlplanner__tests__common_sql_generation__measure_switch_cross_join.snap @@ -0,0 +1,52 @@ +--- +source: cubesqlplanner/src/tests/common_sql_generation.rs +expression: sql +--- + WITH +cte_0 AS ( SELECT "orders".currency "orders__currency", date_trunc('year', ("orders".date::timestamptz AT TIME ZONE 'UTC')) "orders__date_year", sum("orders".amount_usd) "orders__amount_usd" + FROM orders AS "orders" + WHERE ("orders".date >= $_0_$::timestamptz AND "orders".date <= $_1_$::timestamptz) AND ("orders".currency = $_2_$) + GROUP BY 1, 2 + ORDER BY 2 ASC), +cte_1 AS ( SELECT "orders".currency "orders__currency", date_trunc('year', ("orders".date::timestamptz AT TIME ZONE 'UTC')) "orders__date_year", sum("orders".amount_eur) "orders__amount_eur" + FROM orders AS "orders" + WHERE ("orders".date >= $_3_$::timestamptz AND "orders".date <= $_4_$::timestamptz) AND ("orders".currency = $_5_$) + GROUP BY 1, 2 + ORDER BY 2 ASC), +cte_2 AS ( SELECT "orders".currency "orders__currency", date_trunc('year', ("orders".date::timestamptz AT TIME ZONE 'UTC')) "orders__date_year", sum("orders".amount_gbp) "orders__amount_gbp" + FROM orders AS "orders" + WHERE ("orders".date >= $_6_$::timestamptz AND "orders".date <= $_7_$::timestamptz) AND ("orders".currency = $_8_$) + GROUP BY 1, 2 + ORDER BY 2 ASC), +cte_3 AS ( SELECT "fk_aggregate_keys"."orders__currency" "orders__currency", "fk_aggregate_keys"."orders__date_year" "orders__date_year", CASE "fk_aggregate_keys"."orders__currency" WHEN 'USD' THEN "q_0"."orders__amount_usd" WHEN 'EUR' THEN "q_1"."orders__amount_eur" WHEN 'GBP' THEN "q_2"."orders__amount_gbp" ELSE "q_0"."orders__amount_usd" END "orders__amount_in_currency" + FROM (SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" + FROM (SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" + FROM cte_0 AS "cte_0" + UNION ALL + SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" + FROM cte_1 AS "cte_1" + UNION ALL + SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" + FROM cte_2 AS "cte_2") AS "pk_aggregate_keys_source") AS "fk_aggregate_keys" + LEFT JOIN (SELECT * + FROM cte_0 AS "cte_0") AS "q_0" ON (("fk_aggregate_keys"."orders__currency" = "q_0"."orders__currency" OR (("fk_aggregate_keys"."orders__currency" IS NULL) AND ("q_0"."orders__currency" IS NULL)))) AND (("fk_aggregate_keys"."orders__date_year" = "q_0"."orders__date_year" OR (("fk_aggregate_keys"."orders__date_year" IS NULL) AND ("q_0"."orders__date_year" IS NULL)))) + LEFT JOIN (SELECT * + FROM cte_1 AS "cte_1") AS "q_1" ON (("fk_aggregate_keys"."orders__currency" = "q_1"."orders__currency" OR (("fk_aggregate_keys"."orders__currency" IS NULL) AND ("q_1"."orders__currency" IS NULL)))) AND (("fk_aggregate_keys"."orders__date_year" = "q_1"."orders__date_year" OR (("fk_aggregate_keys"."orders__date_year" IS NULL) AND ("q_1"."orders__date_year" IS NULL)))) + LEFT JOIN (SELECT * + FROM cte_2 AS "cte_2") AS "q_2" ON (("fk_aggregate_keys"."orders__currency" = "q_2"."orders__currency" OR (("fk_aggregate_keys"."orders__currency" IS NULL) AND ("q_2"."orders__currency" IS NULL)))) AND (("fk_aggregate_keys"."orders__date_year" = "q_2"."orders__date_year" OR (("fk_aggregate_keys"."orders__date_year" IS NULL) AND ("q_2"."orders__date_year" IS NULL))))) +SELECT "fk_aggregate_keys"."orders__currency" "orders__currency", "fk_aggregate_keys"."orders__date_year" "orders__date_year", "q_0"."orders__amount_usd" "orders__amount_usd", "q_1"."orders__amount_in_currency" "orders__amount_in_currency" +FROM (SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" +FROM (SELECT "orders".currency "orders__currency", date_trunc('year', ("orders".date::timestamptz AT TIME ZONE 'UTC')) "orders__date_year" +FROM orders AS "orders" +WHERE ("orders".date >= $_11_$::timestamptz AND "orders".date <= $_12_$::timestamptz) +GROUP BY 1, 2 + UNION ALL +SELECT DISTINCT "orders__currency" "orders__currency", "orders__date_year" "orders__date_year" +FROM cte_3 AS "cte_3") AS "pk_aggregate_keys_source") AS "fk_aggregate_keys" +LEFT JOIN (SELECT "orders".currency "orders__currency", date_trunc('year', ("orders".date::timestamptz AT TIME ZONE 'UTC')) "orders__date_year", sum("orders".amount_usd) "orders__amount_usd" +FROM orders AS "orders" +WHERE ("orders".date >= $_9_$::timestamptz AND "orders".date <= $_10_$::timestamptz) +GROUP BY 1, 2) AS "q_0" ON (("fk_aggregate_keys"."orders__currency" = "q_0"."orders__currency" OR (("fk_aggregate_keys"."orders__currency" IS NULL) AND ("q_0"."orders__currency" IS NULL)))) AND (("fk_aggregate_keys"."orders__date_year" = "q_0"."orders__date_year" OR (("fk_aggregate_keys"."orders__date_year" IS NULL) AND ("q_0"."orders__date_year" IS NULL)))) +LEFT JOIN (SELECT * +FROM cte_3 AS "cte_3") AS "q_1" ON (("fk_aggregate_keys"."orders__currency" = "q_1"."orders__currency" OR (("fk_aggregate_keys"."orders__currency" IS NULL) AND ("q_1"."orders__currency" IS NULL)))) AND (("fk_aggregate_keys"."orders__date_year" = "q_1"."orders__date_year" OR (("fk_aggregate_keys"."orders__date_year" IS NULL) AND ("q_1"."orders__date_year" IS NULL)))) +ORDER BY 2 ASC