From aef599945a42bffc8b7fdc6be653638802964cb9 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 6 Aug 2024 10:38:31 -0700 Subject: [PATCH] [FEAT] Enable broadcast strategy on anti and semi joins (#2621) --- daft/dataframe/dataframe.py | 2 -- src/daft-plan/src/physical_planner/translate.rs | 12 ++---------- tests/dataframe/test_joins.py | 4 ---- 3 files changed, 2 insertions(+), 16 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index d1348b9625..f080303501 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1497,8 +1497,6 @@ def join( raise ValueError("Sort merge join only supports inner joins") elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer: raise ValueError("Broadcast join does not support outer joins") - elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Anti: - raise ValueError("Broadcast join does not support Anti joins") left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,)) right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,)) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index ec063649c4..b62a762db8 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -548,16 +548,8 @@ pub(super) fn translate_single_logical_node( "Broadcast join does not support outer joins.".to_string(), )); } - (JoinType::Anti, _) => { - return Err(common_error::DaftError::ValueError( - "Broadcast join does not support anti joins.".to_string(), - )); - } - (JoinType::Semi, _) => { - return Err(common_error::DaftError::ValueError( - "Broadcast join does not support semi joins.".to_string(), - )); - } + (JoinType::Anti, _) => true, + (JoinType::Semi, _) => true, }; if is_swapped { diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index e2e8cc962e..56314f4549 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -13,10 +13,6 @@ def skip_invalid_join_strategies(join_strategy, join_type): pytest.skip("Sort merge currently only supports inner joins") elif join_strategy == "broadcast" and join_type == "outer": pytest.skip("Broadcast join does not support outer joins") - elif join_strategy == "broadcast" and join_type == "anti": - pytest.skip("Broadcast join does not support anti joins") - elif join_strategy == "broadcast" and join_type == "semi": - pytest.skip("Broadcast join does not support semi joins") def test_invalid_join_strategies(make_df):