From a5c702be62b6a22ba6e569f82ebdc6de59afa35a Mon Sep 17 00:00:00 2001
From: Clark Zinzow
Subject: [PATCH] [FEAT] [New Query Planner] Groupby support, aggregation fixes, support for remaining aggregation ops (#1272)

This PR adds support for `df.groupby()`, fixes miscellaneous aggregation issues, and adds support for the remaining (non-sum) aggregation ops.

This PR builds off of https://github.com/Eventual-Inc/Daft/pull/1257, where @xcharleslin implemented the core of this PR; this PR just wires things together and fixes a few minor things. From that PR's description:

"Ported over the logic in our existing AggregationPlanBuilder. Groupby-aggregates should now be fully supported (including multi-partition). Additionally, this PR improves on Daft's existing aggregation logic by using semantic IDs in intermediate results, so that redundant intermediates are not computed. E.g. before, getting the Sum and Mean of a column would compute and carry around two copies of the intermediate sum, one for the Sum and one for the Mean. Now, all stages address their required intermediates by semantic ID, eliminating these duplicates."

---------

Co-authored-by: Xiayue Charles Lin
---
 daft/execution/rust_physical_plan_shim.py |  34 +--
 daft/logical/rust_logical_plan.py         |  18 +-
 src/daft-core/src/datatypes/field.rs      |   8 +-
 src/daft-dsl/src/expr.rs                  |   2 +-
 src/daft-plan/src/builder.rs              |  32 ++-
 src/daft-plan/src/ops/agg.rs              |  23 +-
 src/daft-plan/src/physical_ops/agg.rs     |   6 +-
 src/daft-plan/src/physical_plan.rs        |   4 +-
 src/daft-plan/src/planner.rs              | 245 +++++++++++++++-------
 tests/cookbook/test_aggregations.py       |  28 +--
 tests/dataframe/test_aggregations.py      |  24 +--
 11 files changed, 283 insertions(+), 141 deletions(-)
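Reviewer note: below is a minimal, self-contained sketch of the two-stage planning idea described above (intermediate aggregations deduplicated by semantic ID). All names in it (Agg, plan_two_stage, need, the "column.local_op()" ID format) are hypothetical illustrations, not Daft APIs.

    from dataclasses import dataclass

    # Hypothetical stand-in for an aggregation expression; illustration only.
    @dataclass(frozen=True)
    class Agg:
        op: str      # "sum", "count", "mean", "min", "max", "list", "concat"
        column: str  # input column (or intermediate) name
        alias: str   # user-facing output name

        def semantic_id(self) -> str:
            # The semantic ID names the computation, not the output alias, so
            # requesting sum(x) twice (directly and inside mean(x)) collapses
            # to a single intermediate.
            return f"{self.column}.local_{self.op}()"


    def plan_two_stage(aggs: list[Agg]):
        # Stage 1 runs per partition; stage 2 runs after the gather/shuffle.
        # Both are keyed by semantic ID, so duplicate intermediates are dropped.
        first_stage: dict[str, Agg] = {}
        second_stage: dict[str, Agg] = {}
        final_projection: list[str] = []

        def need(op: str, column: str) -> str:
            sid = Agg(op, column, "").semantic_id()
            first_stage.setdefault(sid, Agg(op, column, sid))
            # How a partial result is re-aggregated in stage 2: partial counts
            # are summed, partial lists are concatenated, everything else
            # reuses its own op (sum of sums, min of mins, ...).
            reagg = {"count": "sum", "list": "concat"}.get(op, op)
            second_stage.setdefault(sid, Agg(reagg, sid, sid))
            return sid

        for agg in aggs:
            if agg.op == "mean":
                # Mean decomposes into a shared sum and count plus a final divide.
                s = need("sum", agg.column)
                c = need("count", agg.column)
                final_projection.append(f"({s} / {c}) AS {agg.alias}")
            else:
                sid = need(agg.op, agg.column)
                final_projection.append(f"{sid} AS {agg.alias}")

        return first_stage, second_stage, final_projection


    if __name__ == "__main__":
        # Sum and Mean over the same column: the intermediate sum appears once.
        stage1, _, project = plan_two_stage(
            [Agg("sum", "x", "x_sum"), Agg("mean", "x", "x_mean")]
        )
        print(sorted(stage1))  # ['x.local_count()', 'x.local_sum()']
        print(project)

In the patch itself, the same pattern is implemented in src/daft-plan/src/planner.rs over AggExprs, using HashMaps keyed by Expr::semantic_id for the first- and second-stage aggregations, plus a final Project that renames the second-stage results to their user-facing output names.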
diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 92db0888cc..977b3c8953 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -21,23 +21,6 @@
 PartitionT = TypeVar("PartitionT")
 
 
-def local_aggregate(
-    input: physical_plan.InProgressPhysicalPlan[PartitionT],
-    agg_exprs: list[PyExpr],
-    group_by: list[PyExpr],
-) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
-    aggregation_step = execution_step.Aggregate(
-        to_agg=[Expression._from_pyexpr(pyexpr) for pyexpr in agg_exprs],
-        group_by=ExpressionsProjection([Expression._from_pyexpr(pyexpr) for pyexpr in group_by]),
-    )
-
-    return physical_plan.pipeline_instruction(
-        child_plan=input,
-        pipeable_instruction=aggregation_step,
-        resource_request=ResourceRequest(),
-    )
-
-
 def tabular_scan(
     schema: PySchema, file_info_table: PyTable, file_format_config: FileFormatConfig, limit: int
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
@@ -94,6 +77,23 @@ def explode(
     )
 
 
+def local_aggregate(
+    input: physical_plan.InProgressPhysicalPlan[PartitionT],
+    agg_exprs: list[PyExpr],
+    group_by: list[PyExpr],
+) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
+    aggregation_step = execution_step.Aggregate(
+        to_agg=[Expression._from_pyexpr(pyexpr) for pyexpr in agg_exprs],
+        group_by=ExpressionsProjection([Expression._from_pyexpr(pyexpr) for pyexpr in group_by]),
+    )
+
+    return physical_plan.pipeline_instruction(
+        child_plan=input,
+        pipeable_instruction=aggregation_step,
+        resource_request=ResourceRequest(),
+    )
+
+
 def sort(
     input: physical_plan.InProgressPhysicalPlan[PartitionT],
     sort_by: list[PyExpr],
diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py
index 6bc7bfd334..78932bb7f2 100644
--- a/daft/logical/rust_logical_plan.py
+++ b/daft/logical/rust_logical_plan.py
@@ -174,10 +174,24 @@ def agg(
         for expr, op in to_agg:
             if op == "sum":
                 exprs.append(expr._sum())
+            elif op == "count":
+                exprs.append(expr._count())
+            elif op == "min":
+                exprs.append(expr._min())
+            elif op == "max":
+                exprs.append(expr._max())
+            elif op == "mean":
+                exprs.append(expr._mean())
+            elif op == "list":
+                exprs.append(expr._agg_list())
+            elif op == "concat":
+                exprs.append(expr._agg_concat())
             else:
-                raise NotImplementedError()
+                raise NotImplementedError(f"Aggregation {op} is not implemented.")
 
-        builder = self._builder.aggregate([expr._expr for expr in exprs])
+        builder = self._builder.aggregate(
+            [expr._expr for expr in exprs], group_by.to_inner_py_exprs() if group_by is not None else []
+        )
         return RustLogicalPlanBuilder(builder)
 
     def join(  # type: ignore[override]
diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs
index 3b9431bbb4..229bf7b683 100644
--- a/src/daft-core/src/datatypes/field.rs
+++ b/src/daft-core/src/datatypes/field.rs
@@ -19,17 +19,19 @@ pub struct Field {
 
 #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
 pub struct FieldID {
-    pub id: String,
+    pub id: Arc<str>,
 }
 
 impl FieldID {
-    pub fn new<S: Into<String>>(id: S) -> Self {
+    pub fn new<S: Into<Arc<str>>>(id: S) -> Self {
         Self { id: id.into() }
     }
 
     /// Create a Field ID directly from a real column name.
     /// Performs sanitization on the name so it can be composed.
-    pub fn from_name(name: String) -> Self {
+    pub fn from_name<S: Into<String>>(name: S) -> Self {
+        let name: String = name.into();
+
         // Escape parentheses within a string,
         // since we will use parentheses as delimiters in our semantic expression IDs.
         let sanitized = name
diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs
index e54cbf6ba5..1684ebc67b 100644
--- a/src/daft-dsl/src/expr.rs
+++ b/src/daft-dsl/src/expr.rs
@@ -293,7 +293,7 @@ impl Expr {
             Function { func, inputs } => {
                 let inputs = inputs
                     .iter()
-                    .map(|expr| expr.semantic_id(schema).id)
+                    .map(|expr| expr.semantic_id(schema).id.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                // TODO: check for function idempotency here.
diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs
index 06b6a6c14b..2a87cb0b9a 100644
--- a/src/daft-plan/src/builder.rs
+++ b/src/daft-plan/src/builder.rs
@@ -1,5 +1,7 @@
 use std::sync::Arc;
 
+use common_error::DaftResult;
+
 use crate::{logical_plan::LogicalPlan, ResourceRequest};
 
 #[cfg(feature = "python")]
@@ -177,7 +179,11 @@ impl LogicalPlanBuilder {
         Ok(logical_plan_builder)
     }
 
-    pub fn aggregate(&self, agg_exprs: Vec<PyExpr>) -> PyResult<LogicalPlanBuilder> {
+    pub fn aggregate(
+        &self,
+        agg_exprs: Vec<PyExpr>,
+        groupby_exprs: Vec<PyExpr>,
+    ) -> PyResult<LogicalPlanBuilder> {
         use crate::ops::Aggregate;
         let agg_exprs = agg_exprs
             .iter()
@@ -189,7 +195,29 @@ impl LogicalPlanBuilder {
             ))),
             })
             .collect::<PyResult<Vec<_>>>()?;
-        let logical_plan: LogicalPlan = Aggregate::new(agg_exprs, self.plan.clone()).into();
+        let groupby_exprs = groupby_exprs
+            .iter()
+            .map(|expr| expr.clone().into())
+            .collect::<Vec<Expr>>();
+
+        let input_schema = self.plan.schema();
+        let fields = groupby_exprs
+            .iter()
+            .map(|expr| expr.to_field(&input_schema))
+            .chain(
+                agg_exprs
+                    .iter()
+                    .map(|agg_expr| agg_expr.to_field(&input_schema)),
+            )
+            .collect::<DaftResult<Vec<_>>>()?;
+        let output_schema = Schema::new(fields)?;
+        let logical_plan: LogicalPlan = Aggregate::new(
+            agg_exprs,
+            groupby_exprs,
+            output_schema.into(),
+            self.plan.clone(),
+        )
+        .into();
         let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
         Ok(logical_plan_builder)
     }
diff --git a/src/daft-plan/src/ops/agg.rs b/src/daft-plan/src/ops/agg.rs
index 641629c9a0..c6178d1a74 100644
--- a/src/daft-plan/src/ops/agg.rs
+++ b/src/daft-plan/src/ops/agg.rs
@@ -11,20 +11,25 @@ pub struct Aggregate {
     pub aggregations: Vec<AggExpr>,
 
     /// Grouping to apply.
-    pub group_by: Vec<Expr>,
+    pub groupby: Vec<Expr>,
+
+    pub output_schema: SchemaRef,
 
     // Upstream node.
     pub input: Arc<LogicalPlan>,
 }
 
 impl Aggregate {
-    pub(crate) fn new(aggregations: Vec<AggExpr>, input: Arc<LogicalPlan>) -> Self {
-        // TEMP: No groupbys supported for now.
-        let group_by: Vec<Expr> = vec![];
-
+    pub(crate) fn new(
+        aggregations: Vec<AggExpr>,
+        groupby: Vec<Expr>,
+        output_schema: SchemaRef,
+        input: Arc<LogicalPlan>,
+    ) -> Self {
         Self {
             aggregations,
-            group_by,
+            groupby,
+            output_schema,
             input,
         }
     }
@@ -33,7 +38,7 @@ impl Aggregate {
         let source_schema = self.input.schema();
 
         let fields = self
-            .group_by
+            .groupby
             .iter()
             .map(|expr| expr.to_field(&source_schema).unwrap())
             .chain(
@@ -48,8 +53,8 @@ impl Aggregate {
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
         res.push(format!("Aggregation: {:?}", self.aggregations));
-        if !self.group_by.is_empty() {
-            res.push(format!(" Group by: {:?}", self.group_by));
+        if !self.groupby.is_empty() {
+            res.push(format!(" Group by: {:?}", self.groupby));
         }
         res.push(format!(" Output schema: {}", self.schema().short_string()));
         res
diff --git a/src/daft-plan/src/physical_ops/agg.rs b/src/daft-plan/src/physical_ops/agg.rs
index 4d948ab489..327e7f5b74 100644
--- a/src/daft-plan/src/physical_ops/agg.rs
+++ b/src/daft-plan/src/physical_ops/agg.rs
@@ -11,7 +11,7 @@ pub struct Aggregate {
     pub aggregations: Vec<AggExpr>,
 
     /// Grouping to apply.
-    pub group_by: Vec<Expr>,
+    pub groupby: Vec<Expr>,
 
     // Upstream node.
     pub input: Arc<PhysicalPlan>,
@@ -21,11 +21,11 @@ impl Aggregate {
     pub(crate) fn new(
         input: Arc<PhysicalPlan>,
         aggregations: Vec<AggExpr>,
-        group_by: Vec<Expr>,
+        groupby: Vec<Expr>,
     ) -> Self {
         Self {
             aggregations,
-            group_by,
+            groupby,
             input,
         }
     }
diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs
index efbfb603ac..7a909b3e82 100644
--- a/src/daft-plan/src/physical_plan.rs
+++ b/src/daft-plan/src/physical_plan.rs
@@ -373,7 +373,7 @@ impl PhysicalPlan {
             }
             PhysicalPlan::Aggregate(Aggregate {
                 aggregations,
-                group_by,
+                groupby,
                 input,
                 ..
             }) => {
@@ -382,7 +382,7 @@ impl PhysicalPlan {
                     .iter()
                     .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone())))
                     .collect();
-                let groupbys_as_pyexprs: Vec<PyExpr> = group_by
+                let groupbys_as_pyexprs: Vec<PyExpr> = groupby
                     .iter()
                     .map(|expr| PyExpr::from(expr.clone()))
                     .collect();
diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs
index 854310ff7b..36024e4632 100644
--- a/src/daft-plan/src/planner.rs
+++ b/src/daft-plan/src/planner.rs
@@ -1,5 +1,5 @@
-use std::cmp::max;
 use std::sync::Arc;
+use std::{cmp::max, collections::HashMap};
 
 use common_error::DaftResult;
 use daft_dsl::Expr;
@@ -192,97 +192,190 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult<PhysicalPlan> {
         }
         LogicalPlan::Aggregate(LogicalAggregate {
             aggregations,
-            group_by,
+            groupby,
             input,
+            ..
         }) => {
-            use daft_dsl::AggExpr::*;
-            let result_plan = plan(input)?;
-
-            if !group_by.is_empty() {
-                unimplemented!("{:?}", group_by);
-            }
+            use daft_dsl::AggExpr::{self, *};
+            use daft_dsl::Expr::Column;
+            let input_plan = plan(input)?;
 
             let num_input_partitions = logical_plan.partition_spec().num_partitions;
 
             let result_plan = match num_input_partitions {
                 1 => PhysicalPlan::Aggregate(Aggregate::new(
-                    result_plan.into(),
+                    input_plan.into(),
                     aggregations.clone(),
-                    vec![],
+                    groupby.clone(),
                 )),
                 _ => {
-                    // Resolve and assign intermediate names for the aggregations.
                     let schema = logical_plan.schema();
-                    let intermediate_names: Vec<FieldID> = aggregations
-                        .iter()
-                        .map(|agg_expr| agg_expr.semantic_id(&schema))
-                        .collect();
-
-                    let first_stage_aggs: Vec<AggExpr> = aggregations
-                        .iter()
-                        .zip(intermediate_names.iter())
-                        .map(|(agg_expr, field_id)| match agg_expr {
-                            Count(e) => Count(e.alias(field_id.id.clone()).into()),
-                            Sum(e) => Sum(e.alias(field_id.id.clone()).into()),
-                            Mean(e) => Mean(e.alias(field_id.id.clone()).into()),
-                            Min(e) => Min(e.alias(field_id.id.clone()).into()),
-                            Max(e) => Max(e.alias(field_id.id.clone()).into()),
-                            List(e) => List(e.alias(field_id.id.clone()).into()),
-                            Concat(e) => Concat(e.alias(field_id.id.clone()).into()),
-                        })
-                        .collect();
+                    // Aggregations to apply in the first and second stages.
+                    // Semantic column name -> AggExpr
+                    let mut first_stage_aggs: HashMap<Arc<str>, AggExpr> = HashMap::new();
+                    let mut second_stage_aggs: HashMap<Arc<str>, AggExpr> = HashMap::new();
+                    // Project the aggregation results to their final output names
+                    let mut final_exprs: Vec<Expr> = groupby.clone();
-                    let second_stage_aggs: Vec<AggExpr> = intermediate_names
-                        .iter()
-                        .zip(schema.fields.keys())
-                        .zip(aggregations.iter())
-                        .map(|((field_id, original_name), agg_expr)| match agg_expr {
-                            Count(_) => Count(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Sum(_) => Sum(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            Mean(_) => Mean(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Min(_) => Min(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            Max(_) => Max(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            List(_) => List(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Concat(_) => Concat(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                        })
-                        .collect();
+
+                    for agg_expr in aggregations {
+                        let output_name = agg_expr.name().unwrap();
+                        match agg_expr {
+                            Count(e) => {
+                                let count_id = agg_expr.semantic_id(&schema).id;
+                                let sum_of_count_id =
+                                    Sum(Column(count_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(count_id.clone())
+                                    .or_insert(Count(e.alias(count_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_count_id.clone())
+                                    .or_insert(Sum(Column(count_id.clone())
+                                        .alias(sum_of_count_id.clone())
+                                        .into()));
+                                final_exprs
+                                    .push(Column(sum_of_count_id.clone()).alias(output_name));
+                            }
+                            Sum(e) => {
+                                let sum_id = agg_expr.semantic_id(&schema).id;
+                                let sum_of_sum_id =
+                                    Sum(Column(sum_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(sum_id.clone())
+                                    .or_insert(Sum(e.alias(sum_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_sum_id.clone())
+                                    .or_insert(Sum(Column(sum_id.clone())
+                                        .alias(sum_of_sum_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(sum_of_sum_id.clone()).alias(output_name));
+                            }
+                            Mean(e) => {
+                                let sum_id = Sum(e.clone()).semantic_id(&schema).id;
+                                let count_id = Count(e.clone()).semantic_id(&schema).id;
+                                let sum_of_sum_id =
+                                    Sum(Column(sum_id.clone()).into()).semantic_id(&schema).id;
+                                let sum_of_count_id =
+                                    Sum(Column(count_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(sum_id.clone())
+                                    .or_insert(Sum(e.alias(sum_id.clone()).clone().into()));
+                                first_stage_aggs
+                                    .entry(count_id.clone())
+                                    .or_insert(Count(e.alias(count_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_sum_id.clone())
+                                    .or_insert(Sum(Column(sum_id.clone())
+                                        .alias(sum_of_sum_id.clone())
+                                        .into()));
+                                second_stage_aggs
+                                    .entry(sum_of_count_id.clone())
+                                    .or_insert(Sum(Column(count_id.clone())
+                                        .alias(sum_of_count_id.clone())
+                                        .into()));
+                                final_exprs.push(
+                                    (Column(sum_of_sum_id.clone())
+                                        / Column(sum_of_count_id.clone()))
+                                    .alias(output_name),
+                                );
+                            }
+                            Min(e) => {
+                                let min_id = agg_expr.semantic_id(&schema).id;
+                                let min_of_min_id =
+                                    Min(Column(min_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(min_id.clone())
+                                    .or_insert(Min(e.alias(min_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(min_of_min_id.clone())
+                                    .or_insert(Min(Column(min_id.clone())
+                                        .alias(min_of_min_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(min_of_min_id.clone()).alias(output_name));
+                            }
+                            Max(e) => {
+                                let max_id = agg_expr.semantic_id(&schema).id;
+                                let max_of_max_id =
+                                    Max(Column(max_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(max_id.clone())
+                                    .or_insert(Max(e.alias(max_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(max_of_max_id.clone())
+                                    .or_insert(Max(Column(max_id.clone())
+                                        .alias(max_of_max_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(max_of_max_id.clone()).alias(output_name));
+                            }
+                            List(e) => {
+                                let list_id = agg_expr.semantic_id(&schema).id;
+                                let concat_of_list_id = Concat(Column(list_id.clone()).into())
+                                    .semantic_id(&schema)
+                                    .id;
+                                first_stage_aggs
+                                    .entry(list_id.clone())
+                                    .or_insert(List(e.alias(list_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(concat_of_list_id.clone())
+                                    .or_insert(Concat(
+                                        Column(list_id.clone())
+                                            .alias(concat_of_list_id.clone())
+                                            .into(),
+                                    ));
+                                final_exprs
+                                    .push(Column(concat_of_list_id.clone()).alias(output_name));
+                            }
+                            Concat(e) => {
+                                let concat_id = agg_expr.semantic_id(&schema).id;
+                                let concat_of_concat_id = Concat(Column(concat_id.clone()).into())
+                                    .semantic_id(&schema)
+                                    .id;
+                                first_stage_aggs
+                                    .entry(concat_id.clone())
+                                    .or_insert(Concat(e.alias(concat_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(concat_of_concat_id.clone())
+                                    .or_insert(Concat(
+                                        Column(concat_id.clone())
+                                            .alias(concat_of_concat_id.clone())
+                                            .into(),
+                                    ));
+                                final_exprs
+                                    .push(Column(concat_of_concat_id.clone()).alias(output_name));
+                            }
+                        }
+                    }
 
-                    let result_plan = PhysicalPlan::Aggregate(Aggregate::new(
-                        result_plan.into(),
-                        first_stage_aggs,
-                        vec![],
+                    let first_stage_agg = PhysicalPlan::Aggregate(Aggregate::new(
+                        input_plan.into(),
+                        first_stage_aggs.values().cloned().collect(),
+                        groupby.clone(),
                     ));
-                    let result_plan = PhysicalPlan::Coalesce(Coalesce::new(
-                        result_plan.into(),
-                        num_input_partitions,
-                        1,
+                    let gather_plan = if groupby.is_empty() {
+                        PhysicalPlan::Coalesce(Coalesce::new(
+                            first_stage_agg.into(),
+                            num_input_partitions,
+                            1,
+                        ))
+                    } else {
+                        let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new(
+                            num_input_partitions,
+                            groupby.clone(),
+                            first_stage_agg.into(),
+                        ));
+                        PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into()))
+                    };
+
+                    let _second_stage_agg = PhysicalPlan::Aggregate(Aggregate::new(
+                        gather_plan.into(),
+                        second_stage_aggs.values().cloned().collect(),
+                        groupby.clone(),
                     ));
-                    PhysicalPlan::Aggregate(Aggregate::new(
-                        result_plan.into(),
-                        second_stage_aggs,
-                        vec![],
+
+                    PhysicalPlan::Project(Project::new(
+                        final_exprs,
+                        Default::default(),
+                        _second_stage_agg.into(),
                     ))
                 }
             };
diff --git a/tests/cookbook/test_aggregations.py b/tests/cookbook/test_aggregations.py
index e1ad6b39d6..9206c72404 100644
--- a/tests/cookbook/test_aggregations.py
+++ b/tests/cookbook/test_aggregations.py
@@ -7,7 +7,7 @@
 from tests.conftest import assert_df_equals
 
 
-def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """Sums across an entire column for the entire table"""
     daft_df = daft_df.repartition(repartition_nparts).sum(col("Unique Key").alias("unique_key_sum"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -17,7 +17,7 @@ def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_sum")
sort_key="unique_key_sum") -def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """Averages across a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).mean(col("Unique Key").alias("unique_key_mean")) service_requests_csv_pd_df = pd.DataFrame.from_records( @@ -27,7 +27,7 @@ def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean") -def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """min across a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).min(col("Unique Key").alias("unique_key_min")) service_requests_csv_pd_df = pd.DataFrame.from_records( @@ -37,7 +37,7 @@ def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_min") -def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """max across a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).max(col("Unique Key").alias("unique_key_max")) service_requests_csv_pd_df = pd.DataFrame.from_records( @@ -47,7 +47,7 @@ def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_max") -def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """count a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).count(col("Unique Key").alias("unique_key_count")) service_requests_csv_pd_df = pd.DataFrame.from_records( @@ -58,7 +58,7 @@ def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_count") -def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """list agg a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).agg_list(col("Unique Key").alias("unique_key_list")).collect() unique_key_list = service_requests_csv_pd_df["Unique Key"].to_list() @@ -68,7 +68,7 @@ def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts): assert set(result_list[0]) == set(unique_key_list) -def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """Averages across a column for entire table""" daft_df = daft_df.repartition(repartition_nparts).agg( [ @@ -92,7 +92,7 @@ def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean") -def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): """Sums across an entire column for the entire table filtered by a certain condition""" daft_df = ( daft_df.repartition(repartition_nparts) @@ -119,7 +119,7 @@ def test_filtered_sum(daft_df, service_requests_csv_pd_df, 
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Sums across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).sum(col("Unique Key"))
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).sum("Unique Key").reset_index()
@@ -134,7 +134,7 @@ def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Sums across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).mean(col("Unique Key"))
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).mean("Unique Key").reset_index()
@@ -149,7 +149,7 @@ def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, k
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """count across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).count()
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).count().reset_index()
@@ -167,7 +167,7 @@ def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts,
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """min across groups"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
@@ -188,7 +188,7 @@ def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """max across groups"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
@@ -209,7 +209,7 @@ def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupSortKeys:2"),
     ],
 )
-def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Test sorting after a groupby"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index 1667204fbb..29d33e64c8 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -15,7 +15,7 @@
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_agg_global(repartition_nparts):
+def test_agg_global(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [1, 2, 3],
@@ -46,7 +46,7 @@
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_agg_global_all_null(repartition_nparts):
+def test_agg_global_all_null(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3],
@@ -82,7 +82,7 @@ def test_agg_global_all_null(repartition_nparts):
     assert pa.Table.from_pydict(daft_cols) == pa.Table.from_pydict(expected)
 
 
-def test_agg_global_empty():
+def test_agg_global_empty(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0],
@@ -119,7 +119,7 @@ def test_agg_global_empty():
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 7])
-def test_agg_groupby(repartition_nparts):
+def test_agg_groupby(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "group": [1, 1, 1, 2, 2, 2],
@@ -164,7 +164,7 @@ def test_agg_groupby(repartition_nparts):
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 5])
-def test_agg_groupby_all_null(repartition_nparts):
+def test_agg_groupby_all_null(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3, 4],
@@ -203,7 +203,7 @@ def test_agg_groupby_all_null(repartition_nparts):
     )
 
 
-def test_agg_groupby_null_type_column():
+def test_agg_groupby_null_type_column(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [1, 2, 3, 4],
@@ -222,7 +222,7 @@ def test_agg_groupby_null_type_column():
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 5])
-def test_null_groupby_keys(repartition_nparts):
+def test_null_groupby_keys(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3, 4],
@@ -252,7 +252,7 @@ def test_null_groupby_keys(repartition_nparts):
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_all_null_groupby_keys(repartition_nparts):
+def test_all_null_groupby_keys(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2],
@@ -281,7 +281,7 @@ def test_all_null_groupby_keys(repartition_nparts):
     assert set(daft_cols["list"][0]) == {1, 2, 3}
 
 
-def test_null_type_column_groupby_keys():
+def test_null_type_column_groupby_keys(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2],
@@ -294,7 +294,7 @@ def test_null_type_column_groupby_keys():
         daft_df.groupby(col("group"))
 
 
-def test_agg_groupby_empty():
+def test_agg_groupby_empty(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0],
@@ -337,7 +337,7 @@ class CustomObject:
     val: int
 
 
-def test_agg_pyobjects():
+def test_agg_pyobjects(use_new_planner):
     objects = [CustomObject(val=0), None, CustomObject(val=1)]
     df = daft.from_pydict({"objs": objects})
     df = df.into_partitions(2)
@@ -354,7 +354,7 @@ def test_agg_pyobjects():
     assert res["list"] == [objects]
 
 
-def test_groupby_agg_pyobjects():
+def test_groupby_agg_pyobjects(use_new_planner):
     objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
     df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
     df = df.into_partitions(2)
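For reference, a rough usage-level sketch of the groupby aggregations this PR enables, modeled on the tests above. The import path for col and the (expression, op-name) tuple form of DataFrame.agg are assumptions about the Daft version targeted here and may differ between releases.

    import daft
    from daft.expressions import col  # assumed import path; may be `daft.col` in other versions

    df = daft.from_pydict(
        {
            "group": [1, 1, 1, 2, 2, 2],
            "values": [1, None, 2, 2, None, 4],
        }
    ).repartition(2)

    # A multi-partition groupby-aggregation: planned as a per-partition
    # pre-aggregation, a hash shuffle on the group keys, a second aggregation,
    # and a final projection back to the requested output names.
    result = (
        df.groupby(col("group"))
        .agg(
            [
                (col("values").alias("sum"), "sum"),
                (col("values").alias("mean"), "mean"),
                (col("values").alias("list"), "list"),
            ]
        )
        .to_pydict()
    )
    print(result)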