diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index b9a310d09b..de66377a79 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -290,15 +290,18 @@ pub(super) fn translate_single_logical_node( let (first_stage_aggs, second_stage_aggs, final_exprs) = populate_aggregation_stages(aggregations, &schema, groupby); - let first_stage_agg = if first_stage_aggs.is_empty() { - input_physical + let (first_stage_agg, groupby) = if first_stage_aggs.is_empty() { + (input_physical, groupby.clone()) } else { - PhysicalPlan::Aggregate(Aggregate::new( - input_physical, - first_stage_aggs.values().cloned().collect(), - groupby.clone(), - )) - .arced() + ( + PhysicalPlan::Aggregate(Aggregate::new( + input_physical, + first_stage_aggs.values().cloned().collect(), + groupby.clone(), + )) + .arced(), + groupby.iter().map(|e| col(e.name())).collect(), + ) }; let gather_plan = if groupby.is_empty() { PhysicalPlan::Coalesce(Coalesce::new( @@ -323,7 +326,7 @@ pub(super) fn translate_single_logical_node( let second_stage_agg = PhysicalPlan::Aggregate(Aggregate::new( gather_plan, second_stage_aggs.values().cloned().collect(), - groupby.clone(), + groupby, )); PhysicalPlan::Project(Project::try_new(second_stage_agg.into(), final_exprs)?) @@ -771,7 +774,7 @@ pub fn populate_aggregation_stages( let mut first_stage_aggs: HashMap, AggExpr> = HashMap::new(); let mut second_stage_aggs: HashMap, AggExpr> = HashMap::new(); // Project the aggregation results to their final output names - let mut final_exprs: Vec = group_by.to_vec(); + let mut final_exprs: Vec = group_by.iter().map(|e| col(e.name())).collect(); for agg_expr in aggregations { let output_name = agg_expr.name(); diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py index 1f68b68a2a..74fe889ce0 100644 --- a/tests/dataframe/test_aggregations.py +++ b/tests/dataframe/test_aggregations.py @@ -314,6 +314,48 @@ def test_agg_groupby_empty(make_df): ) +@pytest.mark.parametrize("repartition_nparts", [1, 2, 7]) +def test_agg_groupby_with_alias(make_df, repartition_nparts): + daft_df = make_df( + { + "group": [1, 1, 1, 2, 2, 2], + "values": [1, None, 2, 2, None, 4], + }, + repartition=repartition_nparts, + ) + daft_df = daft_df.groupby(daft_df["group"].alias("group_alias")).agg( + [ + col("values").sum().alias("sum"), + col("values").mean().alias("mean"), + col("values").min().alias("min"), + col("values").max().alias("max"), + col("values").count().alias("count"), + col("values").agg_list().alias("list"), + ] + ) + expected = { + "group_alias": [1, 2], + "sum": [3, 6], + "mean": [1.5, 3], + "min": [1, 2], + "max": [2, 4], + "count": [2, 2], + "list": [[1, None, 2], [2, None, 4]], + } + + daft_df.collect() + daft_cols = daft_df.to_pydict() + res_list = daft_cols.pop("list") + exp_list = expected.pop("list") + + assert sort_arrow_table(pa.Table.from_pydict(daft_cols), "group_alias") == sort_arrow_table( + pa.Table.from_pydict(expected), "group_alias" + ) + + arg_sort = np.argsort(daft_cols["group_alias"]) + assert freeze([list(map(set, res_list))[i] for i in arg_sort]) == freeze(list(map(set, exp_list))) + + @dataclass class CustomObject: val: int diff --git a/tests/dataframe/test_approx_percentiles_aggregations.py b/tests/dataframe/test_approx_percentiles_aggregations.py index d9b9febd48..d64f1a2381 100644 --- a/tests/dataframe/test_approx_percentiles_aggregations.py +++ b/tests/dataframe/test_approx_percentiles_aggregations.py @@ -146,3 +146,35 @@ def test_approx_percentiles_groupby_all_nulls(make_df, repartition_nparts, perce ) daft_cols = daft_df.to_pydict() assert daft_cols["percentiles"] == expected + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 5]) +@pytest.mark.parametrize( + "percentiles_expected", + [ + (0.5, [2.0, 2.0]), + ([0.5], [[2.0], [2.0]]), + ([0.5, 0.5], [[2.0, 2.0], [2.0, 2.0]]), + ], +) +def test_approx_percentiles_groupby_with_alias(make_df, repartition_nparts, percentiles_expected): + percentiles, expected = percentiles_expected + daft_df = make_df( + { + "id": [1, 1, 1, 2], + "values": [1, 2, 3, 2], + }, + repartition=repartition_nparts, + ) + daft_df = daft_df.groupby(daft_df["id"].alias("id_alias")).agg( + [ + col("values").approx_percentiles(percentiles).alias("percentiles"), + ] + ) + daft_cols = daft_df.to_pydict() + pd.testing.assert_series_equal( + pd.Series(daft_cols["percentiles"]), + pd.Series(expected), + check_exact=False, + rtol=0.02, + ) diff --git a/tests/dataframe/test_map_groups.py b/tests/dataframe/test_map_groups.py index 0d5fe7ef86..4f0f2e29ec 100644 --- a/tests/dataframe/test_map_groups.py +++ b/tests/dataframe/test_map_groups.py @@ -140,3 +140,38 @@ def udf(data): daft_cols = daft_df.to_pydict() assert daft_cols == expected + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_map_groups_with_alias(make_df, repartition_nparts): + daft_df = make_df( + { + "group": [1, 1, 2], + "a": [1, 3, 3], + "b": [5, 6, 7], + }, + repartition=repartition_nparts, + ) + + @daft.udf(return_dtype=daft.DataType.list(daft.DataType.float64())) + def udf(a, b): + a, b = a.to_pylist(), b.to_pylist() + res = [] + for i in range(len(a)): + res.append(a[i] / sum(a) + b[i]) + res.sort() + return [res] + + daft_df = ( + daft_df.groupby(daft_df["group"].alias("group_alias")) + .map_groups(udf(daft_df["a"], daft_df["b"])) + .sort("group_alias", desc=False) + ) + expected = { + "group_alias": [1, 2], + "a": [[5.25, 6.75], [8.0]], + } + + daft_cols = daft_df.to_pydict() + + assert daft_cols == expected