diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index f08646959e3f..c85a547e1c55 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -32,6 +32,7 @@ pub use ndjson::*; pub use parquet::*; use polars_core::prelude::*; use polars_io::RowIndex; +use polars_ops::frame::JoinCoalesce; pub use polars_plan::frame::{AllowedOptimizations, OptState}; use polars_plan::global::FETCH_ROWS; use smartstring::alias::String as SmartString; @@ -1163,7 +1164,7 @@ impl LazyFrame { other, [left_on.into()], [right_on.into()], - JoinArgs::new(JoinType::Outer { coalesce: false }), + JoinArgs::new(JoinType::Outer), ) } @@ -1869,6 +1870,7 @@ pub struct JoinBuilder { force_parallel: bool, suffix: Option, validation: JoinValidation, + coalesce: JoinCoalesce, join_nulls: bool, } impl JoinBuilder { @@ -1885,6 +1887,7 @@ impl JoinBuilder { join_nulls: false, suffix: None, validation: Default::default(), + coalesce: Default::default(), } } @@ -1956,6 +1959,12 @@ impl JoinBuilder { self } + /// Whether to coalesce join columns. + pub fn coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + /// Finish builder pub fn finish(self) -> LazyFrame { let mut opt_state = self.lf.opt_state; @@ -1970,6 +1979,7 @@ impl JoinBuilder { suffix: self.suffix, slice: None, join_nulls: self.join_nulls, + ..Default::default() }; let lp = self diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index a25a015a1e42..b45185d6909d 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -1,3 +1,5 @@ +use polars_ops::frame::JoinCoalesce; + use super::*; fn get_csv_file() -> LazyFrame { @@ -294,7 +296,8 @@ fn test_streaming_partial() -> PolarsResult<()> { .left_on([col("a")]) .right_on([col("a")]) .suffix("_foo") - .how(JoinType::Outer { coalesce: true }) + .how(JoinType::Outer) + .coalesce(JoinCoalesce::CoalesceColumns) .finish(); let q = q.left_join( diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 42339e9425d1..6f53fe881e4b 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -26,6 +26,36 @@ pub struct JoinArgs { pub suffix: Option, pub slice: Option<(i64, usize)>, pub join_nulls: bool, + pub coalesce: JoinCoalesce, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinCoalesce { + #[default] + JoinSpecific, + CoalesceColumns, + KeepColumns, +} + +impl JoinCoalesce { + pub fn coalesce(&self, join_type: &JoinType) -> bool { + use JoinCoalesce::*; + use JoinType::*; + match join_type { + Left | Inner => { + matches!(self, JoinSpecific | CoalesceColumns) + }, + Outer { .. } => { + matches!(self, CoalesceColumns) + }, + #[cfg(feature = "asof_join")] + AsOf(_) => false, + Cross => false, + #[cfg(feature = "semi_anti_join")] + Semi | Anti => false, + } + } } impl Default for JoinArgs { @@ -36,6 +66,7 @@ impl Default for JoinArgs { suffix: None, slice: None, join_nulls: false, + coalesce: Default::default(), } } } @@ -48,9 +79,15 @@ impl JoinArgs { suffix: None, slice: None, join_nulls: false, + coalesce: Default::default(), } } + pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self { + self.coalesce = coalesce; + self + } + pub fn suffix(&self) -> &str { self.suffix.as_deref().unwrap_or("_right") } @@ -61,9 +98,7 @@ impl JoinArgs { pub enum JoinType { Left, Inner, - Outer { - coalesce: bool, - }, + Outer, #[cfg(feature = "asof_join")] AsOf(AsOfOptions), Cross, @@ -73,18 +108,6 @@ pub enum JoinType { Anti, } -impl JoinType { - pub fn merges_join_keys(&self) -> bool { - match self { - Self::Outer { coalesce } => *coalesce, - // Merges them if they are equal - #[cfg(feature = "asof_join")] - Self::AsOf(_) => false, - _ => true, - } - } -} - impl From for JoinArgs { fn from(value: JoinType) -> Self { JoinArgs::new(value) diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index f07667130cc5..9232b0ffeefe 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -267,9 +267,7 @@ pub trait JoinDispatch: IntoDf { || unsafe { other.take_unchecked(&idx_ca_r) }, ); - let JoinType::Outer { coalesce } = args.how else { - unreachable!() - }; + let coalesce = args.coalesce.coalesce(&JoinType::Outer); let out = _finish_join(df_left, df_right, args.suffix.as_deref()); if coalesce { Ok(_coalesce_outer_join( diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index ccc4c72184bd..ec3e48a2e9de 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -290,9 +290,13 @@ pub trait DataFrameJoinOps: IntoDf { // SAFETY: join indices are known to be in bounds || unsafe { left_df._create_left_df_from_slice(join_idx_left, false, !swap) }, || unsafe { - // remove join columns - remove_selected(other, &selected_right) - ._take_unchecked_slice(join_idx_right, true) + if args.coalesce.coalesce(&args.how) { + // remove join columns + Cow::Owned(remove_selected(other, &selected_right)) + } else { + Cow::Borrowed(other) + } + ._take_unchecked_slice(join_idx_right, true) }, ); _finish_join(df_left, df_right, args.suffix.as_deref()) @@ -306,7 +310,14 @@ pub trait DataFrameJoinOps: IntoDf { } let ids = _left_join_multiple_keys(&mut left, &mut right, None, None, args.join_nulls); - left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args) + let other = if args.coalesce.coalesce(&args.how) { + // remove join columns + Cow::Owned(remove_selected(other, &selected_right)) + } else { + Cow::Borrowed(other) + }; + + left_df._finish_left_join(ids, &other, args) }, JoinType::Outer { .. } => { let df_left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; @@ -330,9 +341,7 @@ pub trait DataFrameJoinOps: IntoDf { || unsafe { other.take_unchecked(&idx_ca_r) }, ); - let JoinType::Outer { coalesce } = args.how else { - unreachable!() - }; + let coalesce = args.coalesce.coalesce(&JoinType::Outer); let names_left = selected_left.iter().map(|s| s.name()).collect::>(); let names_right = selected_right.iter().map(|s| s.name()).collect::>(); let out = _finish_join(df_left, df_right, args.suffix.as_deref()); @@ -459,12 +468,7 @@ pub trait DataFrameJoinOps: IntoDf { I: IntoIterator, S: AsRef, { - self.join( - other, - left_on, - right_on, - JoinArgs::new(JoinType::Outer { coalesce: false }), - ) + self.join(other, left_on, right_on, JoinArgs::new(JoinType::Outer)) } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index 864020d1a8a1..1fa7ce58a152 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -5,6 +5,7 @@ use hashbrown::hash_map::RawEntryMut; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; +use polars_ops::prelude::JoinArgs; use polars_utils::arena::Node; use polars_utils::slice::GetSaferUnchecked; use polars_utils::unitvec; @@ -34,6 +35,7 @@ pub struct GenericBuild { materialized_join_cols: Vec>, suffix: Arc, hb: RandomState, + join_args: JoinArgs, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table hash_tables: PartitionedMap, @@ -45,7 +47,6 @@ pub struct GenericBuild { // amortize allocations join_columns: Vec, hashes: Vec, - join_type: JoinType, // the join order is swapped to ensure we hash the smaller table swapped: bool, join_nulls: bool, @@ -59,7 +60,7 @@ impl GenericBuild { #[allow(clippy::too_many_arguments)] pub(crate) fn new( suffix: Arc, - join_type: JoinType, + join_args: JoinArgs, swapped: bool, join_columns_left: Arc>>, join_columns_right: Arc>>, @@ -76,7 +77,7 @@ impl GenericBuild { })); GenericBuild { chunks: vec![], - join_type, + join_args, suffix, hb, swapped, @@ -278,7 +279,7 @@ impl Sink for GenericBuild { fn split(&self, _thread_no: usize) -> Box { let mut new = Self::new( self.suffix.clone(), - self.join_type.clone(), + self.join_args.clone(), self.swapped, self.join_columns_left.clone(), self.join_columns_right.clone(), @@ -317,7 +318,7 @@ impl Sink for GenericBuild { let mut hashes = std::mem::take(&mut self.hashes); hashes.clear(); - match self.join_type { + match self.join_args.how { JoinType::Inner | JoinType::Left => { let probe_operator = GenericJoinProbe::new( left_df, @@ -330,13 +331,14 @@ impl Sink for GenericBuild { self.swapped, hashes, context, - self.join_type.clone(), + self.join_args.how.clone(), self.join_nulls, ); self.placeholder.replace(Box::new(probe_operator)); Ok(FinalizedSink::Operator) }, - JoinType::Outer { coalesce } => { + JoinType::Outer => { + let coalesce = self.join_args.coalesce.coalesce(&JoinType::Outer); let probe_operator = GenericOuterJoinProbe::new( left_df, materialized_join_cols, diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index dcb934db76f1..026b767d8758 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -285,12 +285,12 @@ where }; match jt { - join_type @ JoinType::Inner | join_type @ JoinType::Left => { + JoinType::Inner | JoinType::Left => { let (join_columns_left, join_columns_right) = swap_eval(); Box::new(GenericBuild::<()>::new( Arc::from(options.args.suffix()), - join_type.clone(), + options.args.clone(), swapped, join_columns_left, join_columns_right, @@ -317,7 +317,7 @@ where Box::new(GenericBuild::::new( Arc::from(options.args.suffix()), - jt.clone(), + options.args.clone(), swapped, join_columns_left, join_columns_right, diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index a038a8a53a4e..b5278a3a6481 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -258,7 +258,8 @@ pub(super) fn process_join( already_added_local_to_local_projected.insert(local_name); } // In outer joins both columns remain. So `add_local=true` also for the right table - let add_local = matches!(options.args.how, JoinType::Outer { coalesce: false }); + let add_local = matches!(options.args.how, JoinType::Outer) + && !options.args.coalesce.coalesce(&options.args.how); for e in &right_on { // In case of outer joins we also add the columns. // But before we do that we must check if the column wasn't already added by the lhs. diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs index cef56be58061..1a02cc64e6f4 100644 --- a/crates/polars-plan/src/logical_plan/schema.rs +++ b/crates/polars-plan/src/logical_plan/schema.rs @@ -300,11 +300,12 @@ pub(crate) fn det_join_schema( new_schema.with_column(field.name, field.dtype); arena.clear(); } + let coalesces_join_keys = options.args.coalesce.coalesce(&options.args.how); // except in asof joins. Asof joins are not equi-joins // so the columns that are joined on, may have different // values so if the right has a different name, it is added to the schema #[cfg(feature = "asof_join")] - if !options.args.how.merges_join_keys() { + if !coalesces_join_keys { for (left_on, right_on) in left_on.iter().zip(right_on) { let field_left = left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?; @@ -330,9 +331,9 @@ pub(crate) fn det_join_schema( } for (name, dtype) in schema_right.iter() { - if !join_on_right.contains(name.as_str()) // The names that are joined on are merged - || matches!(&options.args.how, JoinType::Outer{coalesce: false}) - // The names are not merged + if !coalesces_join_keys || // The names are not merged + !join_on_right.contains(name.as_str()) + // The names that are joined on are merged { if schema_left.contains(name.as_str()) { #[cfg(feature = "asof_join")] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 9a9963ed3259..d33d16aa9410 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -293,14 +293,9 @@ impl SQLContext { let (r_name, rf) = self.get_table(&tbl.relation)?; lf = match &tbl.join_operator { JoinOperator::CrossJoin => lf.cross_join(rf), - JoinOperator::FullOuter(constraint) => process_join( - lf, - rf, - constraint, - &l_name, - &r_name, - JoinType::Outer { coalesce: false }, - )?, + JoinOperator::FullOuter(constraint) => { + process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Outer)? + }, JoinOperator::Inner(constraint) => { process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)? }, diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 212de7960562..0542e77f96f1 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -119,7 +119,7 @@ fn test_outer_join() -> PolarsResult<()> { &rain, ["days"], ["days"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!(joined.height(), 5); assert_eq!(joined.column("days")?.sum::().unwrap(), 7); @@ -139,7 +139,7 @@ fn test_outer_join() -> PolarsResult<()> { &df_right, ["a"], ["a"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!(out.column("c_right")?.null_count(), 1); @@ -254,7 +254,7 @@ fn test_join_multiple_columns() { &df_b, ["a", "b"], ["foo", "bar"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); assert!(joined_outer_hack @@ -300,11 +300,7 @@ fn test_join_categorical() { assert_eq!(Vec::from(ca), correct_ham); // test dispatch - for jt in [ - JoinType::Left, - JoinType::Inner, - JoinType::Outer { coalesce: true }, - ] { + for jt in [JoinType::Left, JoinType::Inner, JoinType::Outer] { let out = df_a.join(&df_b, ["b"], ["bar"], jt.into()).unwrap(); let out = out.column("b").unwrap(); assert_eq!( @@ -471,7 +467,7 @@ fn test_joins_with_duplicates() -> PolarsResult<()> { &df_right, ["col1"], ["join_col1"], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); @@ -543,7 +539,7 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { &df_right, &["col1", "join_col2"], &["join_col1", "col2"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); @@ -586,7 +582,7 @@ fn test_join_floats() -> PolarsResult<()> { &df_b, vec!["a", "c"], vec!["foo", "bar"], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), )?; assert_eq!( out.dtypes(), diff --git a/crates/polars/tests/it/joins.rs b/crates/polars/tests/it/joins.rs index 2e5435d1bd2c..80e9c31739b2 100644 --- a/crates/polars/tests/it/joins.rs +++ b/crates/polars/tests/it/joins.rs @@ -23,7 +23,8 @@ fn join_nans_outer() -> PolarsResult<()> { .with(a2) .left_on(vec![col("w"), col("t")]) .right_on(vec![col("w"), col("t")]) - .how(JoinType::Outer { coalesce: true }) + .how(JoinType::Outer) + .coalesce(JoinCoalesce::CoalesceColumns) .join_nulls(true) .finish() .collect()?; diff --git a/crates/polars/tests/it/lazy/projection_queries.rs b/crates/polars/tests/it/lazy/projection_queries.rs index ffe38ad57108..5219717ae0b3 100644 --- a/crates/polars/tests/it/lazy/projection_queries.rs +++ b/crates/polars/tests/it/lazy/projection_queries.rs @@ -54,7 +54,7 @@ fn test_outer_join_with_column_2988() -> PolarsResult<()> { ldf2, [col("key1"), col("key2")], [col("key1"), col("key2")], - JoinType::Outer { coalesce: true }.into(), + JoinArgs::new(JoinType::Outer).with_coalesce(JoinCoalesce::CoalesceColumns), ) .with_columns([col("key1")]) .collect()?; diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 4ec0a5632e46..084e7ae96fd4 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4105,6 +4105,10 @@ def join( msg = "must specify `on` OR `left_on` and `right_on`" raise ValueError(msg) + coalesce = None + if how == "outer_coalesce": + coalesce = True + return self._from_pyldf( self._ldf.join( other._ldf, @@ -4116,6 +4120,7 @@ def join( how, suffix, validate, + coalesce, ) ) diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index 6bbedfc8ba0b..222746beb78d 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -697,8 +697,11 @@ impl FromPyObject<'_> for Wrap { let parsed = match ob.extract::<&str>()? { "inner" => JoinType::Inner, "left" => JoinType::Left, - "outer" => JoinType::Outer{coalesce: false}, - "outer_coalesce" => JoinType::Outer{coalesce: true}, + "outer" => JoinType::Outer, + "outer_coalesce" => { + // TODO! deprecate + JoinType::Outer + }, "semi" => JoinType::Semi, "anti" => JoinType::Anti, #[cfg(feature = "cross_join")] diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index f6fc7679a680..5e89159cdb9d 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -835,7 +835,13 @@ impl PyLazyFrame { how: Wrap, suffix: String, validate: Wrap, + coalesce: Option, ) -> PyResult { + let coalesce = match coalesce { + None => JoinCoalesce::JoinSpecific, + Some(true) => JoinCoalesce::CoalesceColumns, + Some(false) => JoinCoalesce::KeepColumns, + }; let ldf = self.ldf.clone(); let other = other.ldf; let left_on = left_on @@ -856,6 +862,7 @@ impl PyLazyFrame { .force_parallel(force_parallel) .join_nulls(join_nulls) .how(how.0) + .coalesce(coalesce) .validate(validate.0) .suffix(suffix) .finish()