diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index ed807f7d2b..ab44240f5f 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -8,7 +8,7 @@ use snafu::ResultExt; use crate::logical_plan::{CreationSnafu, Result}; use crate::optimization::Transformed; -use crate::{LogicalPlan, ResourceRequest}; +use crate::{LogicalPlan, PartitionScheme, PartitionSpec, ResourceRequest}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Project { @@ -17,6 +17,7 @@ pub struct Project { pub projection: Vec<Expr>, pub resource_request: ResourceRequest, pub projected_schema: SchemaRef, + pub partition_spec: Arc<PartitionSpec>, } impl Project { @@ -38,14 +39,182 @@ impl Project { .context(CreationSnafu)?; Schema::new(fields).context(CreationSnafu)?.into() }; + let partition_spec = + Self::translate_partition_spec(factored_input.partition_spec(), &factored_projection); Ok(Self { input: factored_input, projection: factored_projection, resource_request, projected_schema, + partition_spec, }) } + pub fn multiline_display(&self) -> Vec<String> { + vec![ + format!( + "Project: {}", + self.projection + .iter() + .map(|e| e.to_string()) + .collect::<Vec<_>>() + .join(", ") + ), + format!("Partition spec = {:?}", self.partition_spec), + ] + } + + fn translate_partition_spec( + input_pspec: Arc<PartitionSpec>, + projection: &Vec<Expr>, + ) -> Arc<PartitionSpec> { + // Given an input partition spec, and a new projection, + // produce the new partition spec. + + use crate::PartitionScheme::*; + match input_pspec.scheme { + // If the scheme is vacuous, the result partition spec is the same. + Random | Unknown => input_pspec.clone(), + // Otherwise, need to reevaluate the partition scheme for each expression. + Range | Hash => { + // See what columns the projection directly translates into new columns.
+ let mut old_colname_to_new_colname = IndexMap::new(); + for expr in projection { + if let Some(oldname) = expr.input_mapping() { + let newname = expr.name().unwrap().to_string(); + // Add the oldname -> newname mapping, + // but don't overwrite any existing identity mappings (e.g. "a" -> "a"). + if old_colname_to_new_colname.get(&oldname) != Some(&oldname) { + old_colname_to_new_colname.insert(oldname, newname); + } + } + } + + // Then, see if we can fully translate the partition spec. + let maybe_new_pspec = input_pspec + .by + .as_ref() + .unwrap() + .iter() + .map(|e| Self::translate_partition_spec_expr(e, &old_colname_to_new_colname)) + .collect::<std::result::Result<Vec<Expr>, _>>(); + maybe_new_pspec.map_or_else( + |()| { + PartitionSpec::new_internal( + PartitionScheme::Unknown, + input_pspec.num_partitions, + None, + ) + .into() + }, + |new_pspec: Vec<Expr>| { + PartitionSpec::new_internal( + input_pspec.scheme.clone(), + input_pspec.num_partitions, + Some(new_pspec), + ) + .into() + }, + ) + } + } + } + + fn translate_partition_spec_expr( + pspec_expr: &Expr, + old_colname_to_new_colname: &IndexMap<String, String>, + ) -> std::result::Result<Expr, ()> { + // Given a single expression of an input partition spec, + // translate it to a new expression in the given projection. + // Returns: + // - Ok(expr) with expr being the translation, or + // - Err(()) if no translation is possible in the new projection.
+ + match pspec_expr { + Expr::Column(name) => match old_colname_to_new_colname.get(name.as_ref()) { + Some(newname) => Ok(Expr::Column(newname.as_str().into())), + None => Err(()), + }, + Expr::Literal(_) => Ok(pspec_expr.clone()), + Expr::Alias(child, name) => { + let newchild = Self::translate_partition_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::Alias(newchild.into(), name.clone())) + } + Expr::BinaryOp { op, left, right } => { + let newleft = + Self::translate_partition_spec_expr(left.as_ref(), old_colname_to_new_colname)?; + let newright = Self::translate_partition_spec_expr( + right.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::BinaryOp { + op: *op, + left: newleft.into(), + right: newright.into(), + }) + } + Expr::Cast(child, dtype) => { + let newchild = Self::translate_partition_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::Cast(newchild.into(), dtype.clone())) + } + Expr::Function { func, inputs } => { + let new_inputs = inputs + .iter() + .map(|e| Self::translate_partition_spec_expr(e, old_colname_to_new_colname)) + .collect::<std::result::Result<Vec<Expr>, _>>()?; + Ok(Expr::Function { + func: func.clone(), + inputs: new_inputs, + }) + } + Expr::Not(child) => { + let newchild = Self::translate_partition_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::Not(newchild.into())) + } + Expr::IsNull(child) => { + let newchild = Self::translate_partition_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::IsNull(newchild.into())) + } + Expr::IfElse { + if_true, + if_false, + predicate, + } => { + let newtrue = Self::translate_partition_spec_expr( + if_true.as_ref(), + old_colname_to_new_colname, + )?; + let newfalse = Self::translate_partition_spec_expr( + if_false.as_ref(), + old_colname_to_new_colname, + )?; + let newpred = Self::translate_partition_spec_expr( + predicate.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::IfElse { + if_true: newtrue.into(), +
if_false: newfalse.into(), + predicate: newpred.into(), + }) + } + // Cannot have agg exprs in partition specs. + Expr::Agg(_) => Err(()), + } + } + fn try_factor_subexpressions( input: Arc<LogicalPlan>, projection: Vec<Expr>, @@ -345,9 +514,11 @@ fn replace_column_with_semantic_id_aggexpr( mod tests { use common_error::DaftResult; use daft_core::{datatypes::Field, DataType}; - use daft_dsl::{binary_op, col, lit, Operator}; + use daft_dsl::{binary_op, col, lit, Expr, Operator}; - use crate::{logical_ops::Project, test::dummy_scan_node, LogicalPlan}; + use crate::{ + logical_ops::Project, test::dummy_scan_node, LogicalPlan, PartitionScheme, PartitionSpec, + }; /// Test that nested common subexpressions are correctly split /// into multiple levels of projections. @@ -456,4 +627,107 @@ mod tests { Ok(()) } + + /// Test that projections preserving column inputs, even through aliasing, + /// do not destroy the partition spec. + #[test] + fn test_partition_spec_preserving() -> DaftResult<()> { + let source = dummy_scan_node(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + Field::new("c", DataType::Int64), + ]) + .repartition( + 3, + vec![Expr::Column("a".into()), Expr::Column("b".into())], + PartitionScheme::Hash, + )? + .build(); + + let expressions = vec![ + (col("a") % lit(2)), // this is now "a" + col("b"), + col("a").alias("aa"), + ]; + + let result_projection = Project::try_new(source, expressions, Default::default())?; + + let expected_pspec = + PartitionSpec::new_internal(PartitionScheme::Hash, 3, Some(vec![col("aa"), col("b")])); + + assert_eq!( + expected_pspec, + result_projection.partition_spec.as_ref().clone() + ); + + Ok(()) + } + + /// Test that projections destroying even a single column input from the partition spec + /// destroys the entire partition spec.
+ #[test] + fn test_partition_spec_destroying() -> DaftResult<()> { + let source = dummy_scan_node(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + Field::new("c", DataType::Int64), + ]) + .repartition( + 3, + vec![Expr::Column("a".into()), Expr::Column("b".into())], + PartitionScheme::Hash, + )? + .build(); + + let expected_pspec = PartitionSpec::new_internal(PartitionScheme::Unknown, 3, None); + + let test_cases = vec![ + vec![col("a"), col("c").alias("b")], // original "b" is gone even though "b" is present + vec![col("b")], // original "a" dropped + vec![col("a") % lit(2), col("b")], // original "a" gone + vec![col("c")], // everything gone + ]; + + for projection in test_cases { + let result_projection = + Project::try_new(source.clone(), projection, Default::default())?; + assert_eq!( + expected_pspec, + result_projection.partition_spec.as_ref().clone() + ); + } + + Ok(()) + } + + /// Test that new partition specs favor existing instead of new names. + /// i.e. ("a", "a" as "b") remains partitioned by "a", not "b" + #[test] + fn test_partition_spec_prefer_existing_names() -> DaftResult<()> { + let source = dummy_scan_node(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + Field::new("c", DataType::Int64), + ]) + .repartition( + 3, + vec![Expr::Column("a".into()), Expr::Column("b".into())], + PartitionScheme::Hash, + )? 
+ .build(); + + let expressions = vec![col("a").alias("y"), col("a"), col("a").alias("z"), col("b")]; + + let result_projection = Project::try_new(source, expressions, Default::default())?; + + let expected_pspec = + PartitionSpec::new_internal(PartitionScheme::Hash, 3, Some(vec![col("a"), col("b")])); + + assert_eq!( + expected_pspec, + result_projection.partition_spec.as_ref().clone() + ); + + Ok(()) + } } diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 9136bc4bf6..d12ab418be 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -125,7 +125,7 @@ impl LogicalPlan { pub fn partition_spec(&self) -> Arc<PartitionSpec> { match self { Self::Source(Source { partition_spec, .. }) => partition_spec.clone(), - Self::Project(Project { input, .. }) => input.partition_spec(), + Self::Project(Project { partition_spec, .. }) => partition_spec.clone(), Self::Filter(Filter { input, .. }) => input.partition_spec(), Self::Limit(Limit { input, .. }) => input.partition_spec(), Self::Explode(Explode { input, .. }) => input.partition_spec(), @@ -283,16 +283,7 @@ impl LogicalPlan { pub fn multiline_display(&self) -> Vec<String> { match self { Self::Source(source) => source.multiline_display(), - Self::Project(Project { projection, .. }) => { - vec![format!( - "Project: {}", - projection - .iter() - .map(|e| e.to_string()) - .collect::<Vec<_>>() - .join(", ") - )] - } + Self::Project(projection) => projection.multiline_display(), Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")], Self::Explode(Explode { to_explode, ..
}) => { diff --git a/src/daft-plan/src/optimization/optimizer.rs b/src/daft-plan/src/optimization/optimizer.rs index 346c93d523..7e1cc7d9e1 100644 --- a/src/daft-plan/src/optimization/optimizer.rs +++ b/src/daft-plan/src/optimization/optimizer.rs @@ -535,7 +535,7 @@ mod tests { assert_eq!(pass_count, 6); let expected = "\ Filter: [[[col(a) < lit(2)] | lit(false)] | lit(false)] & lit(true)\ - \n Project: col(a) + lit(3) AS c, col(a) + lit(1), col(a) + lit(2) AS b\ + \n Project: col(a) + lit(3) AS c, col(a) + lit(1), col(a) + lit(2) AS b, Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Source: Json, File paths = [/foo], File schema = a (Int64), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64)"; assert_eq!(opt_plan.repr_indent(), expected); Ok(()) diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index 9c272f32c5..8469abe5c6 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -277,7 +277,7 @@ mod tests { .filter(col("a").lt(&lit(2)))? .build(); let expected = "\ - Project: col(a)\ + Project: col(a), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Filter: col(a) < lit(2)\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8)"; assert_optimized_plan_eq(plan, expected)?; @@ -295,7 +295,7 @@ mod tests { .filter(col("a").lt(&lit(2)).and(&col("b").eq(&lit("foo"))))? 
.build(); let expected = "\ - Project: col(a), col(b)\ + Project: col(a), col(b), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Filter: [col(a) < lit(2)] & [col(b) == lit(\"foo\")]\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8)"; assert_optimized_plan_eq(plan, expected)?; @@ -316,7 +316,7 @@ mod tests { // Filter should NOT commute with Project, since this would involve redundant computation. let expected = "\ Filter: col(a) < lit(2)\ - \n Project: col(a) + lit(1)\ + \n Project: col(a) + lit(1), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8)"; assert_optimized_plan_eq(plan, expected)?; Ok(()) @@ -336,7 +336,7 @@ mod tests { .filter(col("a").lt(&lit(2)))? .build(); let expected = "\ - Project: col(a) + lit(1)\ + Project: col(a) + lit(1), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Filter: [col(a) + lit(1)] < lit(2)\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8)"; assert_optimized_plan_eq(plan, expected)?; diff --git a/src/daft-plan/src/optimization/rules/push_down_limit.rs b/src/daft-plan/src/optimization/rules/push_down_limit.rs index f4f2169af9..233d0390d9 100644 --- a/src/daft-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-plan/src/optimization/rules/push_down_limit.rs @@ -233,7 +233,7 @@ mod tests { .limit(5)? 
.build(); let expected = "\ - Project: col(a)\ + Project: col(a), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Utf8), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Utf8), Limit = 5"; assert_optimized_plan_eq(plan, expected)?; Ok(())