Skip to content

Commit

Permalink
refactor(rust)!: prepare for join coalescing argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 1, 2024
1 parent 758b55a commit 8a3599b
Show file tree
Hide file tree
Showing 16 changed files with 120 additions and 71 deletions.
12 changes: 11 additions & 1 deletion crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub use ndjson::*;
pub use parquet::*;
use polars_core::prelude::*;
use polars_io::RowIndex;
use polars_ops::frame::JoinCoalesce;
pub use polars_plan::frame::{AllowedOptimizations, OptState};
use polars_plan::global::FETCH_ROWS;
use smartstring::alias::String as SmartString;
Expand Down Expand Up @@ -1163,7 +1164,7 @@ impl LazyFrame {
other,
[left_on.into()],
[right_on.into()],
JoinArgs::new(JoinType::Outer { coalesce: false }),
JoinArgs::new(JoinType::Outer),
)
}

Expand Down Expand Up @@ -1869,6 +1870,7 @@ pub struct JoinBuilder {
force_parallel: bool,
suffix: Option<String>,
validation: JoinValidation,
coalesce: JoinCoalesce,
join_nulls: bool,
}
impl JoinBuilder {
Expand All @@ -1885,6 +1887,7 @@ impl JoinBuilder {
join_nulls: false,
suffix: None,
validation: Default::default(),
coalesce: Default::default(),
}
}

Expand Down Expand Up @@ -1956,6 +1959,12 @@ impl JoinBuilder {
self
}

/// Set the coalescing behaviour for the join key columns.
///
/// The value is stored on the builder and returned with it so calls can be chained.
pub fn coalesce(self, coalesce: JoinCoalesce) -> Self {
    let mut builder = self;
    builder.coalesce = coalesce;
    builder
}

/// Finish builder
pub fn finish(self) -> LazyFrame {
let mut opt_state = self.lf.opt_state;
Expand All @@ -1970,6 +1979,7 @@ impl JoinBuilder {
suffix: self.suffix,
slice: None,
join_nulls: self.join_nulls,
..Default::default()
};

let lp = self
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-lazy/src/tests/streaming.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use polars_ops::frame::JoinCoalesce;

use super::*;

fn get_csv_file() -> LazyFrame {
Expand Down Expand Up @@ -294,7 +296,8 @@ fn test_streaming_partial() -> PolarsResult<()> {
.left_on([col("a")])
.right_on([col("a")])
.suffix("_foo")
.how(JoinType::Outer { coalesce: true })
.how(JoinType::Outer)
.coalesce(JoinCoalesce::CoalesceColumns)
.finish();

let q = q.left_join(
Expand Down
53 changes: 38 additions & 15 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@ pub struct JoinArgs {
pub suffix: Option<String>,
pub slice: Option<(i64, usize)>,
pub join_nulls: bool,
pub coalesce: JoinCoalesce,
}

/// Selects whether the columns that are joined on get coalesced in the join output.
///
/// The effective decision for a concrete join type is made by `JoinCoalesce::coalesce`.
#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinCoalesce {
/// Default: left and inner joins coalesce, every other join type does not.
#[default]
JoinSpecific,
/// Coalesce for left, inner and outer joins; the remaining join types still do not coalesce.
CoalesceColumns,
/// Never coalesce, regardless of the join type.
KeepColumns,
}

impl JoinCoalesce {
pub fn coalesce(&self, join_type: &JoinType) -> bool {
use JoinCoalesce::*;
use JoinType::*;
match join_type {
Left | Inner => {
matches!(self, JoinSpecific | CoalesceColumns)
},
Outer { .. } => {
matches!(self, CoalesceColumns)
},
#[cfg(feature = "asof_join")]
AsOf(_) => false,
Cross => false,
#[cfg(feature = "semi_anti_join")]
Semi | Anti => false,
}
}
}

impl Default for JoinArgs {
Expand All @@ -36,6 +66,7 @@ impl Default for JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}
}
Expand All @@ -48,9 +79,15 @@ impl JoinArgs {
suffix: None,
slice: None,
join_nulls: false,
coalesce: Default::default(),
}
}

pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {
self.coalesce = coalesce;
self
}

/// The configured suffix, or `"_right"` when none was set.
pub fn suffix(&self) -> &str {
    match self.suffix.as_deref() {
        Some(custom) => custom,
        None => "_right",
    }
}
Expand All @@ -61,9 +98,7 @@ impl JoinArgs {
pub enum JoinType {
Left,
Inner,
Outer {
coalesce: bool,
},
Outer,
#[cfg(feature = "asof_join")]
AsOf(AsOfOptions),
Cross,
Expand All @@ -73,18 +108,6 @@ pub enum JoinType {
Anti,
}

impl JoinType {
pub fn merges_join_keys(&self) -> bool {
match self {
Self::Outer { coalesce } => *coalesce,
// Merges them if they are equal
#[cfg(feature = "asof_join")]
Self::AsOf(_) => false,
_ => true,
}
}
}

impl From<JoinType> for JoinArgs {
fn from(value: JoinType) -> Self {
JoinArgs::new(value)
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,7 @@ pub trait JoinDispatch: IntoDf {
|| unsafe { other.take_unchecked(&idx_ca_r) },
);

let JoinType::Outer { coalesce } = args.how else {
unreachable!()
};
let coalesce = args.coalesce.coalesce(&JoinType::Outer);
let out = _finish_join(df_left, df_right, args.suffix.as_deref());
if coalesce {
Ok(_coalesce_outer_join(
Expand Down
30 changes: 17 additions & 13 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,13 @@ pub trait DataFrameJoinOps: IntoDf {
// SAFETY: join indices are known to be in bounds
|| unsafe { left_df._create_left_df_from_slice(join_idx_left, false, !swap) },
|| unsafe {
// remove join columns
remove_selected(other, &selected_right)
._take_unchecked_slice(join_idx_right, true)
if args.coalesce.coalesce(&args.how) {
// remove join columns
Cow::Owned(remove_selected(other, &selected_right))
} else {
Cow::Borrowed(other)
}
._take_unchecked_slice(join_idx_right, true)
},
);
_finish_join(df_left, df_right, args.suffix.as_deref())
Expand All @@ -306,7 +310,14 @@ pub trait DataFrameJoinOps: IntoDf {
}
let ids =
_left_join_multiple_keys(&mut left, &mut right, None, None, args.join_nulls);
left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args)
let other = if args.coalesce.coalesce(&args.how) {
// remove join columns
Cow::Owned(remove_selected(other, &selected_right))
} else {
Cow::Borrowed(other)
};

left_df._finish_left_join(ids, &other, args)
},
JoinType::Outer { .. } => {
let df_left = unsafe { DataFrame::new_no_checks(selected_left_physical) };
Expand All @@ -330,9 +341,7 @@ pub trait DataFrameJoinOps: IntoDf {
|| unsafe { other.take_unchecked(&idx_ca_r) },
);

let JoinType::Outer { coalesce } = args.how else {
unreachable!()
};
let coalesce = args.coalesce.coalesce(&JoinType::Outer);
let names_left = selected_left.iter().map(|s| s.name()).collect::<Vec<_>>();
let names_right = selected_right.iter().map(|s| s.name()).collect::<Vec<_>>();
let out = _finish_join(df_left, df_right, args.suffix.as_deref());
Expand Down Expand Up @@ -459,12 +468,7 @@ pub trait DataFrameJoinOps: IntoDf {
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
self.join(
other,
left_on,
right_on,
JoinArgs::new(JoinType::Outer { coalesce: false }),
)
self.join(other, left_on, right_on, JoinArgs::new(JoinType::Outer))
}
}

Expand Down
16 changes: 9 additions & 7 deletions crates/polars-pipe/src/executors/sinks/joins/generic_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use hashbrown::hash_map::RawEntryMut;
use polars_core::export::ahash::RandomState;
use polars_core::prelude::*;
use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked};
use polars_ops::prelude::JoinArgs;
use polars_utils::arena::Node;
use polars_utils::slice::GetSaferUnchecked;
use polars_utils::unitvec;
Expand Down Expand Up @@ -34,6 +35,7 @@ pub struct GenericBuild<K: ExtraPayload> {
materialized_join_cols: Vec<BinaryArray<i64>>,
suffix: Arc<str>,
hb: RandomState,
join_args: JoinArgs,
// partitioned tables that will be used for probing
// stores the key and the chunk_idx, df_idx of the left table
hash_tables: PartitionedMap<K>,
Expand All @@ -45,7 +47,6 @@ pub struct GenericBuild<K: ExtraPayload> {
// amortize allocations
join_columns: Vec<ArrayRef>,
hashes: Vec<u64>,
join_type: JoinType,
// the join order is swapped to ensure we hash the smaller table
swapped: bool,
join_nulls: bool,
Expand All @@ -59,7 +60,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
suffix: Arc<str>,
join_type: JoinType,
join_args: JoinArgs,
swapped: bool,
join_columns_left: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
join_columns_right: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
Expand All @@ -76,7 +77,7 @@ impl<K: ExtraPayload> GenericBuild<K> {
}));
GenericBuild {
chunks: vec![],
join_type,
join_args,
suffix,
hb,
swapped,
Expand Down Expand Up @@ -278,7 +279,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
fn split(&self, _thread_no: usize) -> Box<dyn Sink> {
let mut new = Self::new(
self.suffix.clone(),
self.join_type.clone(),
self.join_args.clone(),
self.swapped,
self.join_columns_left.clone(),
self.join_columns_right.clone(),
Expand Down Expand Up @@ -317,7 +318,7 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
let mut hashes = std::mem::take(&mut self.hashes);
hashes.clear();

match self.join_type {
match self.join_args.how {
JoinType::Inner | JoinType::Left => {
let probe_operator = GenericJoinProbe::new(
left_df,
Expand All @@ -330,13 +331,14 @@ impl<K: ExtraPayload> Sink for GenericBuild<K> {
self.swapped,
hashes,
context,
self.join_type.clone(),
self.join_args.how.clone(),
self.join_nulls,
);
self.placeholder.replace(Box::new(probe_operator));
Ok(FinalizedSink::Operator)
},
JoinType::Outer { coalesce } => {
JoinType::Outer => {
let coalesce = self.join_args.coalesce.coalesce(&JoinType::Outer);
let probe_operator = GenericOuterJoinProbe::new(
left_df,
materialized_join_cols,
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-pipe/src/pipeline/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,12 @@ where
};

match jt {
join_type @ JoinType::Inner | join_type @ JoinType::Left => {
JoinType::Inner | JoinType::Left => {
let (join_columns_left, join_columns_right) = swap_eval();

Box::new(GenericBuild::<()>::new(
Arc::from(options.args.suffix()),
join_type.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand All @@ -317,7 +317,7 @@ where

Box::new(GenericBuild::<Tracker>::new(
Arc::from(options.args.suffix()),
jt.clone(),
options.args.clone(),
swapped,
join_columns_left,
join_columns_right,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ pub(super) fn process_join(
already_added_local_to_local_projected.insert(local_name);
}
// In outer joins both columns remain. So `add_local=true` also for the right table
let add_local = matches!(options.args.how, JoinType::Outer { coalesce: false });
let add_local = matches!(options.args.how, JoinType::Outer)
&& !options.args.coalesce.coalesce(&options.args.how);
for e in &right_on {
// In case of outer joins we also add the columns.
// But before we do that we must check if the column wasn't already added by the lhs.
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-plan/src/logical_plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,12 @@ pub(crate) fn det_join_schema(
new_schema.with_column(field.name, field.dtype);
arena.clear();
}
let coalesces_join_keys = options.args.coalesce.coalesce(&options.args.how);
// except in asof joins. Asof joins are not equi-joins
// so the columns that are joined on, may have different
// values so if the right has a different name, it is added to the schema
#[cfg(feature = "asof_join")]
if !options.args.how.merges_join_keys() {
if !coalesces_join_keys {
for (left_on, right_on) in left_on.iter().zip(right_on) {
let field_left =
left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?;
Expand All @@ -330,9 +331,9 @@ pub(crate) fn det_join_schema(
}

for (name, dtype) in schema_right.iter() {
if !join_on_right.contains(name.as_str()) // The names that are joined on are merged
|| matches!(&options.args.how, JoinType::Outer{coalesce: false})
// The names are not merged
if !coalesces_join_keys || // The names are not merged
!join_on_right.contains(name.as_str())
// The names that are joined on are merged
{
if schema_left.contains(name.as_str()) {
#[cfg(feature = "asof_join")]
Expand Down
11 changes: 3 additions & 8 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,9 @@ impl SQLContext {
let (r_name, rf) = self.get_table(&tbl.relation)?;
lf = match &tbl.join_operator {
JoinOperator::CrossJoin => lf.cross_join(rf),
JoinOperator::FullOuter(constraint) => process_join(
lf,
rf,
constraint,
&l_name,
&r_name,
JoinType::Outer { coalesce: false },
)?,
JoinOperator::FullOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Outer)?
},
JoinOperator::Inner(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
},
Expand Down
Loading

0 comments on commit 8a3599b

Please sign in to comment.