Skip to content

Commit

Permalink
fix: Turn off cse if cache node found (#15554)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 9, 2024
1 parent 86d3aa0 commit 42d3697
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
33 changes: 18 additions & 15 deletions crates/polars-plan/src/logical_plan/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,24 @@ pub fn optimize(
}

#[cfg(feature = "cse")]
let _cse_plan_changed =
if comm_subplan_elim && members.has_joins_or_unions && members.has_duplicate_scans() {
if verbose {
eprintln!("found multiple sources; run comm_subplan_elim")
}
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);

prune_unused_caches(lp_arena, cid2c);

lp_top = lp;
members.has_cache |= changed;
changed
} else {
false
};
let _cse_plan_changed = if comm_subplan_elim
&& members.has_joins_or_unions
&& members.has_duplicate_scans()
&& !members.has_cache
{
if verbose {
eprintln!("found multiple sources; run comm_subplan_elim")
}
let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena);

prune_unused_caches(lp_arena, cid2c);

lp_top = lp;
members.has_cache |= changed;
changed
} else {
false
};
#[cfg(not(feature = "cse"))]
let _cse_plan_changed = false;

Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2371,7 +2371,11 @@ def lazy(self) -> Self:
return self

def cache(self) -> Self:
"""Cache the result once the execution of the physical plan hits this node."""
"""
Cache the result once the execution of the physical plan hits this node.
It is not recommended using this as the optimizer likely can do a better job.
"""
return self._from_pyldf(self._ldf.cache())

def cast(
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,12 @@ def test_cse_15536() -> None:
data.filter(pl.lit(True) & (pl.col("a") == 7) | (pl.col("a") == 8)),
]
).collect()["a"].to_list() == [6, 9, 7, 8]


def test_cse_15548() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3]})
ldf2 = ldf.filter(pl.col("a") == 1).cache()
ldf3 = pl.concat([ldf, ldf2])

assert len(ldf3.collect(comm_subplan_elim=False)) == 4
assert len(ldf3.collect(comm_subplan_elim=True)) == 4

0 comments on commit 42d3697

Please sign in to comment.