diff --git a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs index 94f1b98ff2850..5d1841237bf7e 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/checks.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -28,7 +28,14 @@ pub(super) fn streamable_join(args: &JoinArgs) -> bool { let supported = match args.how { #[cfg(feature = "cross_join")] JoinType::Cross => true, - JoinType::Inner | JoinType::Left | JoinType::Outer { .. } => true, + JoinType::Inner | JoinType::Left => { + // no-coalescing not yet supported in streaming + matches!( + args.coalesce, + JoinCoalesce::JoinSpecific | JoinCoalesce::CoalesceColumns + ) + }, + JoinType::Outer { .. } => true, _ => false, }; supported && !args.validation.needs_checks() diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 6a29e2b28c3a0..0ac9de8976c79 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -117,6 +117,8 @@ pub trait DataFrameJoinOps: IntoDf { let left_df = self.to_df(); args.validation.is_valid_join(&args.how)?; + let should_coalesce = args.coalesce.coalesce(&args.how); + #[cfg(feature = "cross_join")] if let JoinType::Cross = args.how { return left_df.cross_join(other, args.suffix.as_deref(), args.slice); @@ -202,13 +204,12 @@ pub trait DataFrameJoinOps: IntoDf { if selected_left.len() == 1 { let s_left = &selected_left[0]; let s_right = &selected_right[0]; + let drop_names: Option<&[&str]> = if should_coalesce { None } else { Some(&[]) }; return match args.how { - JoinType::Inner => { - left_df._inner_join_from_series(other, s_left, s_right, args, _verbose, None) - }, - JoinType::Left => { - left_df._left_join_from_series(other, s_left, s_right, args, _verbose, None) - }, + JoinType::Inner => left_df + ._inner_join_from_series(other, s_left, s_right, args, _verbose, drop_names), + JoinType::Left => left_df + 
._left_join_from_series(other, s_left, s_right, args, _verbose, drop_names), JoinType::Outer => left_df._outer_join_from_series(other, s_left, s_right, args), #[cfg(feature = "semi_anti_join")] JoinType::Anti => left_df._semi_anti_join_from_series( @@ -265,7 +266,12 @@ pub trait DataFrameJoinOps: IntoDf { let lhs_keys = prepare_keys_multiple(&selected_left, args.join_nulls)?.into_series(); let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?.into_series(); - let names_right = selected_right.iter().map(|s| s.name()).collect::<Vec<_>>(); + + let drop_names = if should_coalesce { + Some(selected_right.iter().map(|s| s.name()).collect::<Vec<_>>()) + } else { + Some(vec![]) + }; // Multiple keys. match args.how { @@ -278,16 +284,15 @@ pub trait DataFrameJoinOps: IntoDf { }, JoinType::Outer => { let names_left = selected_left.iter().map(|s| s.name()).collect::<Vec<_>>(); - let coalesce = args.coalesce; args.coalesce = JoinCoalesce::KeepColumns; let suffix = args.suffix.clone(); let out = left_df._outer_join_from_series(other, &lhs_keys, &rhs_keys, args); - if coalesce.coalesce(&JoinType::Outer) { + if should_coalesce { Ok(_coalesce_outer_join( out?, &names_left, - &names_right, + drop_names.as_ref().unwrap(), suffix.as_deref(), left_df, )) @@ -301,7 +306,7 @@ pub trait DataFrameJoinOps: IntoDf { &rhs_keys, args, _verbose, - Some(&names_right), + drop_names.as_deref(), ), JoinType::Left => left_df._left_join_from_series( other, @@ -309,7 +314,7 @@ pub trait DataFrameJoinOps: IntoDf { &rhs_keys, args, _verbose, - Some(&names_right), + drop_names.as_deref(), ), #[cfg(feature = "semi_anti_join")] JoinType::Anti | JoinType::Semi => self._join_impl( diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 05c067714b64b..837bd47b0d959 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -6182,6 +6182,7 @@ def join( suffix: str = "_right", validate: JoinValidation = "m:m", join_nulls: bool = False, + 
coalesce: bool | None = None, ) -> DataFrame: """ Join in SQL-like fashion. @@ -6236,6 +6237,11 @@ def join( - This is currently not supported the streaming engine. join_nulls Join on null values. By default null values will never produce matches. + coalesce + Coalescing behavior (merging of join columns). + - None: -> join specific. + - True: -> Always coalesce join columns. + - False: -> Never coalesce join columns. Returns ------- @@ -6336,6 +6342,7 @@ def join( suffix=suffix, validate=validate, join_nulls=join_nulls, + coalesce=coalesce, ) .collect(_eager=True) ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 24ca7fa32cd4a..d0ee96efdc655 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -3711,6 +3711,7 @@ def join_asof( Force the physical plan to evaluate the computation of both DataFrames up to the join in parallel. + Examples -------- >>> from datetime import datetime @@ -3814,6 +3815,7 @@ def join( suffix: str = "_right", validate: JoinValidation = "m:m", join_nulls: bool = False, + coalesce: bool | None = None, allow_parallel: bool = True, force_parallel: bool = False, ) -> Self: @@ -3837,8 +3839,6 @@ def join( right table * *outer* Returns all rows when there is a match in either left or right table - * *outer_coalesce* - Same as 'outer', but coalesces the key columns * *cross* Returns the Cartesian product of rows from both tables * *semi* @@ -3871,6 +3871,11 @@ def join( - This is currently not supported the streaming engine. join_nulls Join on null values. By default null values will never produce matches. + coalesce + Coalescing behavior (merging of join columns). + - None: -> join specific. + - True: -> Always coalesce join columns. + - False: -> Never coalesce join columns. allow_parallel Allow the physical plan to optionally evaluate the computation of both DataFrames up to the join in parallel. 
@@ -3980,7 +3985,6 @@ def join(
             msg = "must specify `on` OR `left_on` and `right_on`"
             raise ValueError(msg)
 
-        coalesce = None
         if how == "outer_coalesce":
             coalesce = True
 
diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py
index 201370fa8d7a4..f06a40f8d103b 100644
--- a/py-polars/tests/unit/operations/test_join.py
+++ b/py-polars/tests/unit/operations/test_join.py
@@ -960,3 +960,33 @@ def test_cross_join_slice_pushdown() -> None:
         },
         schema={"x": pl.UInt16, "x_": pl.UInt16},
     )
+    assert_frame_equal(result, expected)
+
+
+@pytest.mark.parametrize("how", ["left", "inner"])
+@typing.no_type_check
+def test_join_coalesce(how: str) -> None:
+    """Check that `coalesce` is honored by left and inner joins."""
+    a = pl.LazyFrame({"a": [1, 2], "b": [1, 2]})
+    b = pl.LazyFrame(
+        {
+            "a": [1, 2, 1, 2],
+            "b": [5, 7, 8, 9],
+            "c": [1, 2, 1, 2],
+        }
+    )
+
+    q = a.join(b, on="a", coalesce=False, how=how)
+    out = q.collect()
+    assert q.schema == out.schema
+    assert out.columns == ["a", "b", "a_right", "b_right", "c"]
+
+    q = a.join(b, on=["a", "b"], coalesce=False, how=how)
+    out = q.collect()
+    assert q.schema == out.schema
+    assert out.columns == ["a", "b", "a_right", "b_right", "c"]
+
+    q = a.join(b, on=["a", "b"], coalesce=True, how=how)
+    out = q.collect()
+    assert q.schema == out.schema
+    assert out.columns == ["a", "b", "c"]