diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 92db0888cc..977b3c8953 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -21,23 +21,6 @@
 PartitionT = TypeVar("PartitionT")
 
 
-def local_aggregate(
-    input: physical_plan.InProgressPhysicalPlan[PartitionT],
-    agg_exprs: list[PyExpr],
-    group_by: list[PyExpr],
-) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
-    aggregation_step = execution_step.Aggregate(
-        to_agg=[Expression._from_pyexpr(pyexpr) for pyexpr in agg_exprs],
-        group_by=ExpressionsProjection([Expression._from_pyexpr(pyexpr) for pyexpr in group_by]),
-    )
-
-    return physical_plan.pipeline_instruction(
-        child_plan=input,
-        pipeable_instruction=aggregation_step,
-        resource_request=ResourceRequest(),
-    )
-
-
 def tabular_scan(
     schema: PySchema, file_info_table: PyTable, file_format_config: FileFormatConfig, limit: int
 ) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
@@ -94,6 +77,23 @@ def explode(
     )
 
 
+def local_aggregate(
+    input: physical_plan.InProgressPhysicalPlan[PartitionT],
+    agg_exprs: list[PyExpr],
+    group_by: list[PyExpr],
+) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
+    aggregation_step = execution_step.Aggregate(
+        to_agg=[Expression._from_pyexpr(pyexpr) for pyexpr in agg_exprs],
+        group_by=ExpressionsProjection([Expression._from_pyexpr(pyexpr) for pyexpr in group_by]),
+    )
+
+    return physical_plan.pipeline_instruction(
+        child_plan=input,
+        pipeable_instruction=aggregation_step,
+        resource_request=ResourceRequest(),
+    )
+
+
 def sort(
     input: physical_plan.InProgressPhysicalPlan[PartitionT],
     sort_by: list[PyExpr],
diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py
index 6bc7bfd334..78932bb7f2 100644
--- a/daft/logical/rust_logical_plan.py
+++ b/daft/logical/rust_logical_plan.py
@@ -174,10 +174,24 @@ def agg(
         for expr, op in to_agg:
             if op == "sum":
                 exprs.append(expr._sum())
+            elif op == "count":
+                exprs.append(expr._count())
+            elif op == "min":
+                exprs.append(expr._min())
+            elif op == "max":
+                exprs.append(expr._max())
+            elif op == "mean":
+                exprs.append(expr._mean())
+            elif op == "list":
+                exprs.append(expr._agg_list())
+            elif op == "concat":
+                exprs.append(expr._agg_concat())
             else:
-                raise NotImplementedError()
+                raise NotImplementedError(f"Aggregation {op} is not implemented.")
 
-        builder = self._builder.aggregate([expr._expr for expr in exprs])
+        builder = self._builder.aggregate(
+            [expr._expr for expr in exprs], group_by.to_inner_py_exprs() if group_by is not None else []
+        )
         return RustLogicalPlanBuilder(builder)
 
     def join(  # type: ignore[override]
diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs
index 3b9431bbb4..229bf7b683 100644
--- a/src/daft-core/src/datatypes/field.rs
+++ b/src/daft-core/src/datatypes/field.rs
@@ -19,17 +19,19 @@ pub struct Field {
 
 #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)]
 pub struct FieldID {
-    pub id: String,
+    pub id: Arc<str>,
 }
 
 impl FieldID {
-    pub fn new<S: Into<String>>(id: S) -> Self {
+    pub fn new<S: Into<Arc<str>>>(id: S) -> Self {
         Self { id: id.into() }
     }
 
     /// Create a Field ID directly from a real column name.
     /// Performs sanitization on the name so it can be composed.
-    pub fn from_name(name: String) -> Self {
+    pub fn from_name<S: Into<String>>(name: S) -> Self {
+        let name: String = name.into();
+
         // Escape parentheses within a string,
         // since we will use parentheses as delimiters in our semantic expression IDs.
         let sanitized = name
diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs
index e54cbf6ba5..1684ebc67b 100644
--- a/src/daft-dsl/src/expr.rs
+++ b/src/daft-dsl/src/expr.rs
@@ -293,7 +293,7 @@ impl Expr {
             Function { func, inputs } => {
                 let inputs = inputs
                     .iter()
-                    .map(|expr| expr.semantic_id(schema).id)
+                    .map(|expr| expr.semantic_id(schema).id.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                 // TODO: check for function idempotency here.
diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs
index 06b6a6c14b..2a87cb0b9a 100644
--- a/src/daft-plan/src/builder.rs
+++ b/src/daft-plan/src/builder.rs
@@ -1,5 +1,7 @@
 use std::sync::Arc;
 
+use common_error::DaftResult;
+
 use crate::{logical_plan::LogicalPlan, ResourceRequest};
 
 #[cfg(feature = "python")]
@@ -177,7 +179,11 @@ impl LogicalPlanBuilder {
         Ok(logical_plan_builder)
     }
 
-    pub fn aggregate(&self, agg_exprs: Vec<PyExpr>) -> PyResult<LogicalPlanBuilder> {
+    pub fn aggregate(
+        &self,
+        agg_exprs: Vec<PyExpr>,
+        groupby_exprs: Vec<PyExpr>,
+    ) -> PyResult<LogicalPlanBuilder> {
         use crate::ops::Aggregate;
         let agg_exprs = agg_exprs
             .iter()
@@ -189,7 +195,29 @@ impl LogicalPlanBuilder {
                 ))),
             })
             .collect::<PyResult<Vec<daft_dsl::AggExpr>>>()?;
-        let logical_plan: LogicalPlan = Aggregate::new(agg_exprs, self.plan.clone()).into();
+        let groupby_exprs = groupby_exprs
+            .iter()
+            .map(|expr| expr.clone().into())
+            .collect::<Vec<daft_dsl::Expr>>();
+
+        let input_schema = self.plan.schema();
+        let fields = groupby_exprs
+            .iter()
+            .map(|expr| expr.to_field(&input_schema))
+            .chain(
+                agg_exprs
+                    .iter()
+                    .map(|agg_expr| agg_expr.to_field(&input_schema)),
+            )
+            .collect::<DaftResult<Vec<_>>>()?;
+        let output_schema = Schema::new(fields)?;
+        let logical_plan: LogicalPlan = Aggregate::new(
+            agg_exprs,
+            groupby_exprs,
+            output_schema.into(),
+            self.plan.clone(),
+        )
+        .into();
         let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
         Ok(logical_plan_builder)
     }
diff --git a/src/daft-plan/src/ops/agg.rs b/src/daft-plan/src/ops/agg.rs
index 641629c9a0..c6178d1a74 100644
--- a/src/daft-plan/src/ops/agg.rs
+++ b/src/daft-plan/src/ops/agg.rs
@@ -11,20 +11,25 @@ pub struct Aggregate {
     pub aggregations: Vec<AggExpr>,
 
     /// Grouping to apply.
-    pub group_by: Vec<Expr>,
+    pub groupby: Vec<Expr>,
+
+    pub output_schema: SchemaRef,
 
     // Upstream node.
     pub input: Arc<LogicalPlan>,
 }
 
 impl Aggregate {
-    pub(crate) fn new(aggregations: Vec<AggExpr>, input: Arc<LogicalPlan>) -> Self {
-        // TEMP: No groupbys supported for now.
-        let group_by: Vec<Expr> = vec![];
-
+    pub(crate) fn new(
+        aggregations: Vec<AggExpr>,
+        groupby: Vec<Expr>,
+        output_schema: SchemaRef,
+        input: Arc<LogicalPlan>,
+    ) -> Self {
         Self {
             aggregations,
-            group_by,
+            groupby,
+            output_schema,
             input,
         }
     }
@@ -33,7 +38,7 @@ impl Aggregate {
         let source_schema = self.input.schema();
 
         let fields = self
-            .group_by
+            .groupby
             .iter()
             .map(|expr| expr.to_field(&source_schema).unwrap())
             .chain(
@@ -48,8 +53,8 @@ impl Aggregate {
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
         res.push(format!("Aggregation: {:?}", self.aggregations));
-        if !self.group_by.is_empty() {
-            res.push(format!("  Group by: {:?}", self.group_by));
+        if !self.groupby.is_empty() {
+            res.push(format!("  Group by: {:?}", self.groupby));
         }
         res.push(format!("  Output schema: {}", self.schema().short_string()));
         res
diff --git a/src/daft-plan/src/physical_ops/agg.rs b/src/daft-plan/src/physical_ops/agg.rs
index 4d948ab489..327e7f5b74 100644
--- a/src/daft-plan/src/physical_ops/agg.rs
+++ b/src/daft-plan/src/physical_ops/agg.rs
@@ -11,7 +11,7 @@ pub struct Aggregate {
     pub aggregations: Vec<AggExpr>,
 
     /// Grouping to apply.
-    pub group_by: Vec<Expr>,
+    pub groupby: Vec<Expr>,
 
     // Upstream node.
     pub input: Arc<PhysicalPlan>,
@@ -21,11 +21,11 @@ impl Aggregate {
     pub(crate) fn new(
         input: Arc<PhysicalPlan>,
         aggregations: Vec<AggExpr>,
-        group_by: Vec<Expr>,
+        groupby: Vec<Expr>,
     ) -> Self {
         Self {
             aggregations,
-            group_by,
+            groupby,
             input,
         }
     }
diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs
index efbfb603ac..7a909b3e82 100644
--- a/src/daft-plan/src/physical_plan.rs
+++ b/src/daft-plan/src/physical_plan.rs
@@ -373,7 +373,7 @@ impl PhysicalPlan {
             }
             PhysicalPlan::Aggregate(Aggregate {
                 aggregations,
-                group_by,
+                groupby,
                 input,
                 ..
             }) => {
@@ -382,7 +382,7 @@ impl PhysicalPlan {
                     .iter()
                     .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone())))
                     .collect();
-                let groupbys_as_pyexprs: Vec<PyExpr> = group_by
+                let groupbys_as_pyexprs: Vec<PyExpr> = groupby
                     .iter()
                     .map(|expr| PyExpr::from(expr.clone()))
                     .collect();
diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs
index 854310ff7b..36024e4632 100644
--- a/src/daft-plan/src/planner.rs
+++ b/src/daft-plan/src/planner.rs
@@ -1,5 +1,5 @@
-use std::cmp::max;
 use std::sync::Arc;
+use std::{cmp::max, collections::HashMap};
 
 use common_error::DaftResult;
 use daft_dsl::Expr;
@@ -192,97 +192,190 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult<PhysicalPlan> {
         }
         LogicalPlan::Aggregate(LogicalAggregate {
             aggregations,
-            group_by,
+            groupby,
             input,
+            ..
         }) => {
-            use daft_dsl::AggExpr::*;
-            let result_plan = plan(input)?;
-
-            if !group_by.is_empty() {
-                unimplemented!("{:?}", group_by);
-            }
+            use daft_dsl::AggExpr::{self, *};
+            use daft_dsl::Expr::Column;
+            let input_plan = plan(input)?;
 
             let num_input_partitions = logical_plan.partition_spec().num_partitions;
 
             let result_plan = match num_input_partitions {
                 1 => PhysicalPlan::Aggregate(Aggregate::new(
-                    result_plan.into(),
+                    input_plan.into(),
                     aggregations.clone(),
-                    vec![],
+                    groupby.clone(),
                 )),
                 _ => {
-                    // Resolve and assign intermediate names for the aggregations.
                     let schema = logical_plan.schema();
-                    let intermediate_names: Vec<FieldID> = aggregations
-                        .iter()
-                        .map(|agg_expr| agg_expr.semantic_id(&schema))
-                        .collect();
-                    let first_stage_aggs: Vec<AggExpr> = aggregations
-                        .iter()
-                        .zip(intermediate_names.iter())
-                        .map(|(agg_expr, field_id)| match agg_expr {
-                            Count(e) => Count(e.alias(field_id.id.clone()).into()),
-                            Sum(e) => Sum(e.alias(field_id.id.clone()).into()),
-                            Mean(e) => Mean(e.alias(field_id.id.clone()).into()),
-                            Min(e) => Min(e.alias(field_id.id.clone()).into()),
-                            Max(e) => Max(e.alias(field_id.id.clone()).into()),
-                            List(e) => List(e.alias(field_id.id.clone()).into()),
-                            Concat(e) => Concat(e.alias(field_id.id.clone()).into()),
-                        })
-                        .collect();
+                    // Aggregations to apply in the first and second stages.
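+                    // The multi-partition strategy: pre-aggregate within each input
+                    // partition (first stage), gather the partial results (a coalesce
+                    // when there is no groupby, otherwise a hash fanout on the groupby
+                    // keys followed by a reduce-merge), aggregate the partials again
+                    // (second stage, e.g. a Sum of per-partition Counts), and finally
+                    // project the results back to their user-facing output names
+                    // (Mean is computed here as Sum / Count).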
+                    // Semantic column name -> AggExpr
+                    let mut first_stage_aggs: HashMap<Arc<str>, AggExpr> = HashMap::new();
+                    let mut second_stage_aggs: HashMap<Arc<str>, AggExpr> = HashMap::new();
+                    // Project the aggregation results to their final output names
+                    let mut final_exprs: Vec<Expr> = groupby.clone();
 
-                    let second_stage_aggs: Vec<AggExpr> = intermediate_names
-                        .iter()
-                        .zip(schema.fields.keys())
-                        .zip(aggregations.iter())
-                        .map(|((field_id, original_name), agg_expr)| match agg_expr {
-                            Count(_) => Count(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Sum(_) => Sum(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            Mean(_) => Mean(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Min(_) => Min(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            Max(_) => Max(daft_dsl::Expr::Column(field_id.id.clone().into())
-                                .alias(&**original_name)
-                                .into()),
-                            List(_) => List(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                            Concat(_) => Concat(
-                                daft_dsl::Expr::Column(field_id.id.clone().into())
-                                    .alias(&**original_name)
-                                    .into(),
-                            ),
-                        })
-                        .collect();
+                    for agg_expr in aggregations {
+                        let output_name = agg_expr.name().unwrap();
+                        match agg_expr {
+                            Count(e) => {
+                                let count_id = agg_expr.semantic_id(&schema).id;
+                                let sum_of_count_id =
+                                    Sum(Column(count_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(count_id.clone())
+                                    .or_insert(Count(e.alias(count_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_count_id.clone())
+                                    .or_insert(Sum(Column(count_id.clone())
+                                        .alias(sum_of_count_id.clone())
+                                        .into()));
+                                final_exprs
+                                    .push(Column(sum_of_count_id.clone()).alias(output_name));
+                            }
+                            Sum(e) => {
+                                let sum_id = agg_expr.semantic_id(&schema).id;
+                                let sum_of_sum_id =
+                                    Sum(Column(sum_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(sum_id.clone())
+                                    .or_insert(Sum(e.alias(sum_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_sum_id.clone())
+                                    .or_insert(Sum(Column(sum_id.clone())
+                                        .alias(sum_of_sum_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(sum_of_sum_id.clone()).alias(output_name));
+                            }
+                            Mean(e) => {
+                                let sum_id = Sum(e.clone()).semantic_id(&schema).id;
+                                let count_id = Count(e.clone()).semantic_id(&schema).id;
+                                let sum_of_sum_id =
+                                    Sum(Column(sum_id.clone()).into()).semantic_id(&schema).id;
+                                let sum_of_count_id =
+                                    Sum(Column(count_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(sum_id.clone())
+                                    .or_insert(Sum(e.alias(sum_id.clone()).clone().into()));
+                                first_stage_aggs
+                                    .entry(count_id.clone())
+                                    .or_insert(Count(e.alias(count_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(sum_of_sum_id.clone())
+                                    .or_insert(Sum(Column(sum_id.clone())
+                                        .alias(sum_of_sum_id.clone())
+                                        .into()));
+                                second_stage_aggs
+                                    .entry(sum_of_count_id.clone())
+                                    .or_insert(Sum(Column(count_id.clone())
+                                        .alias(sum_of_count_id.clone())
+                                        .into()));
+                                final_exprs.push(
+                                    (Column(sum_of_sum_id.clone())
+                                        / Column(sum_of_count_id.clone()))
+                                    .alias(output_name),
+                                );
+                            }
+                            Min(e) => {
+                                let min_id = agg_expr.semantic_id(&schema).id;
+                                let min_of_min_id =
+                                    Min(Column(min_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(min_id.clone())
+                                    .or_insert(Min(e.alias(min_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(min_of_min_id.clone())
+                                    .or_insert(Min(Column(min_id.clone())
+                                        .alias(min_of_min_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(min_of_min_id.clone()).alias(output_name));
+                            }
+                            Max(e) => {
+                                let max_id = agg_expr.semantic_id(&schema).id;
+                                let max_of_max_id =
+                                    Max(Column(max_id.clone()).into()).semantic_id(&schema).id;
+                                first_stage_aggs
+                                    .entry(max_id.clone())
+                                    .or_insert(Max(e.alias(max_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(max_of_max_id.clone())
+                                    .or_insert(Max(Column(max_id.clone())
+                                        .alias(max_of_max_id.clone())
+                                        .into()));
+                                final_exprs.push(Column(max_of_max_id.clone()).alias(output_name));
+                            }
+                            List(e) => {
+                                let list_id = agg_expr.semantic_id(&schema).id;
+                                let concat_of_list_id = Concat(Column(list_id.clone()).into())
+                                    .semantic_id(&schema)
+                                    .id;
+                                first_stage_aggs
+                                    .entry(list_id.clone())
+                                    .or_insert(List(e.alias(list_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(concat_of_list_id.clone())
+                                    .or_insert(Concat(
+                                        Column(list_id.clone())
+                                            .alias(concat_of_list_id.clone())
+                                            .into(),
+                                    ));
+                                final_exprs
+                                    .push(Column(concat_of_list_id.clone()).alias(output_name));
+                            }
+                            Concat(e) => {
+                                let concat_id = agg_expr.semantic_id(&schema).id;
+                                let concat_of_concat_id = Concat(Column(concat_id.clone()).into())
+                                    .semantic_id(&schema)
+                                    .id;
+                                first_stage_aggs
+                                    .entry(concat_id.clone())
+                                    .or_insert(Concat(e.alias(concat_id.clone()).clone().into()));
+                                second_stage_aggs
+                                    .entry(concat_of_concat_id.clone())
+                                    .or_insert(Concat(
+                                        Column(concat_id.clone())
+                                            .alias(concat_of_concat_id.clone())
+                                            .into(),
+                                    ));
+                                final_exprs
+                                    .push(Column(concat_of_concat_id.clone()).alias(output_name));
+                            }
+                        }
+                    }
 
-                    let result_plan = PhysicalPlan::Aggregate(Aggregate::new(
-                        result_plan.into(),
-                        first_stage_aggs,
-                        vec![],
+                    let first_stage_agg = PhysicalPlan::Aggregate(Aggregate::new(
+                        input_plan.into(),
+                        first_stage_aggs.values().cloned().collect(),
+                        groupby.clone(),
                     ));
-                    let result_plan = PhysicalPlan::Coalesce(Coalesce::new(
-                        result_plan.into(),
-                        num_input_partitions,
-                        1,
+                    let gather_plan = if groupby.is_empty() {
+                        PhysicalPlan::Coalesce(Coalesce::new(
+                            first_stage_agg.into(),
+                            num_input_partitions,
+                            1,
+                        ))
+                    } else {
+                        let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new(
+                            num_input_partitions,
+                            groupby.clone(),
+                            first_stage_agg.into(),
+                        ));
+                        PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into()))
+                    };
+
+                    let _second_stage_agg = PhysicalPlan::Aggregate(Aggregate::new(
+                        gather_plan.into(),
+                        second_stage_aggs.values().cloned().collect(),
+                        groupby.clone(),
                     ));
-                    PhysicalPlan::Aggregate(Aggregate::new(
-                        result_plan.into(),
-                        second_stage_aggs,
-                        vec![],
+
+                    PhysicalPlan::Project(Project::new(
+                        final_exprs,
+                        Default::default(),
+                        _second_stage_agg.into(),
                     ))
                 }
             };
diff --git a/tests/cookbook/test_aggregations.py b/tests/cookbook/test_aggregations.py
index e1ad6b39d6..9206c72404 100644
--- a/tests/cookbook/test_aggregations.py
+++ b/tests/cookbook/test_aggregations.py
@@ -7,7 +7,7 @@
 from tests.conftest import assert_df_equals
 
 
-def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """Sums across an entire column for the entire table"""
     daft_df = daft_df.repartition(repartition_nparts).sum(col("Unique Key").alias("unique_key_sum"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -17,7 +17,7 @@ def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_sum")
 
 
-def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """Averages across a column for entire table"""
     daft_df = daft_df.repartition(repartition_nparts).mean(col("Unique Key").alias("unique_key_mean"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -27,7 +27,7 @@ def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean")
 
 
-def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """min across a column for entire table"""
     daft_df = daft_df.repartition(repartition_nparts).min(col("Unique Key").alias("unique_key_min"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -37,7 +37,7 @@ def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_min")
 
 
-def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """max across a column for entire table"""
    daft_df = daft_df.repartition(repartition_nparts).max(col("Unique Key").alias("unique_key_max"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -47,7 +47,7 @@ def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_max")
 
 
-def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """count a column for entire table"""
     daft_df = daft_df.repartition(repartition_nparts).count(col("Unique Key").alias("unique_key_count"))
     service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -58,7 +58,7 @@ def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_count")
 
 
-def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """list agg a column for entire table"""
     daft_df = daft_df.repartition(repartition_nparts).agg_list(col("Unique Key").alias("unique_key_list")).collect()
     unique_key_list = service_requests_csv_pd_df["Unique Key"].to_list()
@@ -68,7 +68,7 @@ def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert set(result_list[0]) == set(unique_key_list)
 
 
-def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """Averages across a column for entire table"""
     daft_df = daft_df.repartition(repartition_nparts).agg(
         [
@@ -92,7 +92,7 @@ def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts):
     assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean")
 
 
-def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
+def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
     """Sums across an entire column for the entire table filtered by a certain condition"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
@@ -119,7 +119,7 @@ def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Sums across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).sum(col("Unique Key"))
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).sum("Unique Key").reset_index()
@@ -134,7 +134,7 @@ def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Sums across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).mean(col("Unique Key"))
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).mean("Unique Key").reset_index()
@@ -149,7 +149,7 @@ def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, k
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """count across groups"""
     daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).count()
     service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).count().reset_index()
@@ -167,7 +167,7 @@ def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts,
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """min across groups"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
@@ -188,7 +188,7 @@ def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
     ],
 )
-def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """max across groups"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
@@ -209,7 +209,7 @@ def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
         pytest.param(["Borough", "Complaint Type"], id="NumGroupSortKeys:2"),
     ],
 )
-def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
+def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
     """Test sorting after a groupby"""
     daft_df = (
         daft_df.repartition(repartition_nparts)
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index 1667204fbb..29d33e64c8 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -15,7 +15,7 @@
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_agg_global(repartition_nparts):
+def test_agg_global(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [1, 2, 3],
@@ -46,7 +46,7 @@ def test_agg_global(repartition_nparts, use_new_planner):
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_agg_global_all_null(repartition_nparts):
+def test_agg_global_all_null(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3],
@@ -82,7 +82,7 @@ def test_agg_global_all_null(repartition_nparts):
     assert pa.Table.from_pydict(daft_cols) == pa.Table.from_pydict(expected)
 
 
-def test_agg_global_empty():
+def test_agg_global_empty(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0],
@@ -119,7 +119,7 @@ def test_agg_global_empty():
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 7])
-def test_agg_groupby(repartition_nparts):
+def test_agg_groupby(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "group": [1, 1, 1, 2, 2, 2],
@@ -164,7 +164,7 @@ def test_agg_groupby(repartition_nparts):
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 5])
-def test_agg_groupby_all_null(repartition_nparts):
+def test_agg_groupby_all_null(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3, 4],
@@ -203,7 +203,7 @@ def test_agg_groupby_all_null(repartition_nparts):
     )
 
 
-def test_agg_groupby_null_type_column():
+def test_agg_groupby_null_type_column(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [1, 2, 3, 4],
@@ -222,7 +222,7 @@ def test_agg_groupby_null_type_column():
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 5])
-def test_null_groupby_keys(repartition_nparts):
+def test_null_groupby_keys(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2, 3, 4],
@@ -252,7 +252,7 @@ def test_null_groupby_keys(repartition_nparts):
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
-def test_all_null_groupby_keys(repartition_nparts):
+def test_all_null_groupby_keys(repartition_nparts, use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2],
@@ -281,7 +281,7 @@ def test_all_null_groupby_keys(repartition_nparts):
     assert set(daft_cols["list"][0]) == {1, 2, 3}
 
 
-def test_null_type_column_groupby_keys():
+def test_null_type_column_groupby_keys(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0, 1, 2],
@@ -294,7 +294,7 @@ def test_null_type_column_groupby_keys():
         daft_df.groupby(col("group"))
 
 
-def test_agg_groupby_empty():
+def test_agg_groupby_empty(use_new_planner):
     daft_df = daft.from_pydict(
         {
             "id": [0],
@@ -337,7 +337,7 @@ class CustomObject:
     val: int
 
 
-def test_agg_pyobjects():
+def test_agg_pyobjects(use_new_planner):
     objects = [CustomObject(val=0), None, CustomObject(val=1)]
     df = daft.from_pydict({"objs": objects})
     df = df.into_partitions(2)
@@ -354,7 +354,7 @@ def test_agg_pyobjects():
     assert res["list"] == [objects]
 
 
-def test_groupby_agg_pyobjects():
+def test_groupby_agg_pyobjects(use_new_planner):
     objects = [CustomObject(val=0), CustomObject(val=1), None, None, CustomObject(val=2)]
     df = daft.from_pydict({"objects": objects, "groups": [1, 2, 1, 2, 1]})
     df = df.into_partitions(2)
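
Note: the planner change above splits a grouped aggregation into two physical stages (a per-partition pre-aggregation keyed by semantic column IDs, a shuffle or coalesce on the groupby keys, a second aggregation over the partial results, and a final projection back to the user-facing names, with Mean computed as Sum / Count). The following is a minimal, user-level sketch of the behaviour this enables, written against the same public API the tests above exercise (daft.from_pydict, into_partitions, groupby, mean, to_pydict); the top-level import path for col is an assumption, since the diff only shows col() being used.

import daft
from daft import col  # assumed import path; the tests above only show col() being called

# Two partitions force the multi-stage path: each partition computes partial sums
# and counts per group, the partials are shuffled on the group key and combined,
# and the final projection computes mean = sum / count.
df = daft.from_pydict({"group": [1, 1, 2, 2], "values": [1.0, 2.0, 3.0, 4.0]})
df = df.into_partitions(2)

result = df.groupby(col("group")).mean(col("values")).to_pydict()
# Expected values (row order is not guaranteed): group 1 -> 1.5, group 2 -> 3.5
print(result)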