From 85a5e389af31ee7ae34addc82dfed77c503db51b Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 28 Mar 2024 14:54:20 +0000 Subject: [PATCH 01/30] fix: ensure first datapoint is always included in group_by_dynamic (#15312) --- crates/polars-time/src/windows/bounds.rs | 10 +- crates/polars-time/src/windows/group_by.rs | 3 +- crates/polars-time/src/windows/test.rs | 40 +++-- crates/polars-time/src/windows/window.rs | 143 ++++++++++++++---- py-polars/polars/dataframe/frame.py | 12 +- py-polars/polars/lazyframe/frame.py | 14 +- .../unit/operations/test_group_by_dynamic.py | 35 ++++- 7 files changed, 205 insertions(+), 52 deletions(-) diff --git a/crates/polars-time/src/windows/bounds.rs b/crates/polars-time/src/windows/bounds.rs index eba76ac7fb72..07757620cfe1 100644 --- a/crates/polars-time/src/windows/bounds.rs +++ b/crates/polars-time/src/windows/bounds.rs @@ -63,7 +63,15 @@ impl Bounds { pub(crate) fn is_future(&self, t: i64, closed: ClosedWindow) -> bool { match closed { ClosedWindow::Left | ClosedWindow::None => self.stop <= t, - ClosedWindow::Both | ClosedWindow::Right => t > self.stop, + ClosedWindow::Both | ClosedWindow::Right => self.stop < t, + } + } + + #[inline] + pub(crate) fn is_past(&self, t: i64, closed: ClosedWindow) -> bool { + match closed { + ClosedWindow::Left | ClosedWindow::Both => self.start > t, + ClosedWindow::None | ClosedWindow::Right => self.start >= t, } } } diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 0da725707eb7..c7cb2429fa22 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -180,6 +180,7 @@ pub fn group_by_windows( window .get_overlapping_bounds_iter( boundary, + closed_window, tu, tz.parse::().ok().as_ref(), start_by, @@ -198,7 +199,7 @@ pub fn group_by_windows( _ => { update_groups_and_bounds( window - .get_overlapping_bounds_iter(boundary, tu, None, start_by) + .get_overlapping_bounds_iter(boundary, closed_window, tu, None, start_by) .unwrap(), start_offset, time, diff --git a/crates/polars-time/src/windows/test.rs b/crates/polars-time/src/windows/test.rs index 7b573c14d49f..d0b8dbd67b67 100644 --- a/crates/polars-time/src/windows/test.rs +++ b/crates/polars-time/src/windows/test.rs @@ -148,8 +148,8 @@ fn test_groups_large_interval() { false, Default::default(), ); - assert_eq!(groups.len(), 2); - assert_eq!(groups[1], [2, 2]); + assert_eq!(groups.len(), 3); + assert_eq!(groups[1], [1, 1]); } #[test] @@ -167,7 +167,9 @@ fn test_offset() { Duration::parse("-2m"), ); - let b = w.get_earliest_bounds_ns(t, None).unwrap(); + let b = w + .get_earliest_bounds_ns(t, ClosedWindow::Left, None) + .unwrap(); let start = NaiveDate::from_ymd_opt(2020, 1, 1) .unwrap() .and_hms_opt(23, 58, 0) @@ -209,7 +211,9 @@ fn test_boundaries() { ); // earliest bound is first datapoint: 2021-12-16 00:00:00 - let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); + let b = w + .get_earliest_bounds_ns(ts[0], ClosedWindow::Both, None) + .unwrap(); assert_eq!(b.start, start.and_utc().timestamp_nanos_opt().unwrap()); // test closed: "both" (includes both ends of the interval) @@ -340,9 +344,10 @@ fn test_boundaries() { false, Default::default(), ); - assert_eq!(groups[0], [1, 2]); // 00:00:00 -> 00:30:00 - assert_eq!(groups[1], [3, 2]); // 01:00:00 -> 01:30:00 - assert_eq!(groups[2], [5, 2]); // 02:00:00 -> 02:30:00 + assert_eq!(groups[0], [0, 1]); // (2021-12-15 23:30, 2021-12-16 00:00] + assert_eq!(groups[1], [1, 2]); // (2021-12-16 00:00, 2021-12-16 
00:30] + assert_eq!(groups[2], [3, 2]); // (2021-12-16 00:30, 2021-12-16 01:00] + assert_eq!(groups[3], [5, 2]); // (2021-12-16 01:00, 2021-12-16 01:30] // test closed: "none" (should not include left or right end of interval) let (groups, _, _) = group_by_windows( @@ -388,14 +393,18 @@ fn test_boundaries_2() { // period 1h // offset 30m let offset = Duration::parse("30m"); - let w = Window::new(Duration::parse("2h"), Duration::parse("1h"), offset); + let every = Duration::parse("2h"); + let w = Window::new(every, Duration::parse("1h"), offset); // earliest bound is first datapoint: 2021-12-16 00:00:00 + 30m offset: 2021-12-16 00:30:00 - let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); + // We then shift back by `every` (2h): 2021-12-15 22:30:00 + let b = w + .get_earliest_bounds_ns(ts[0], ClosedWindow::Both, None) + .unwrap(); assert_eq!( b.start, - start.and_utc().timestamp_nanos_opt().unwrap() + offset.duration_ns() + start.and_utc().timestamp_nanos_opt().unwrap() + offset.duration_ns() - every.duration_ns() ); let (groups, lower, higher) = group_by_windows( @@ -520,7 +529,9 @@ fn test_boundaries_ms() { ); // earliest bound is first datapoint: 2021-12-16 00:00:00 - let b = w.get_earliest_bounds_ms(ts[0], None).unwrap(); + let b = w + .get_earliest_bounds_ms(ts[0], ClosedWindow::Both, None) + .unwrap(); assert_eq!(b.start, start.and_utc().timestamp_millis()); // test closed: "both" (includes both ends of the interval) @@ -651,9 +662,10 @@ fn test_boundaries_ms() { false, Default::default(), ); - assert_eq!(groups[0], [1, 2]); // 00:00:00 -> 00:30:00 - assert_eq!(groups[1], [3, 2]); // 01:00:00 -> 01:30:00 - assert_eq!(groups[2], [5, 2]); // 02:00:00 -> 02:30:00 + assert_eq!(groups[0], [0, 1]); // (2021-12-15 23:30, 2021-12-16 00:00] + assert_eq!(groups[1], [1, 2]); // (2021-12-16 00:00, 2021-12-16 00:30] + assert_eq!(groups[2], [3, 2]); // (2021-12-16 00:30, 2021-12-16 01:00] + assert_eq!(groups[3], [5, 2]); // (2021-12-16 01:00, 2021-12-16 01:30] // test closed: "none" (should not include left or right end of interval) let (groups, _, _) = group_by_windows( diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 8adb7520ecfe..16d43c4da3d8 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -8,6 +8,37 @@ use polars_core::prelude::*; use crate::prelude::*; +/// Ensure that earliest datapoint (`t`) is in, or in front of, first window. +/// +/// For example, if we have: +/// +/// - first datapoint is `2020-01-01 01:00` +/// - `every` is `'1d'` +/// - `period` is `'2d'` +/// - `offset` is `'6h'` +/// +/// then truncating the earliest datapoint by `every` and adding `offset` results +/// in the window `[2020-01-01 06:00, 2020-01-03 06:00)`. To give the earliest datapoint +/// a chance of being included, we then shift the window back by `every` to +/// `[2019-12-31 06:00, 2020-01-02 06:00)`. 
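+///
+/// In rough pseudo-code, the adjustment works as follows (our paraphrase of the
+/// function below, with a hypothetical `truncate` helper, not a verbatim copy):
+///
+/// ```text
+/// start = truncate(t, every) + offset
+/// while the window [start, start + period) lies entirely after t:
+///     start -= every
+/// ```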
+pub(crate) fn ensure_t_in_or_in_front_of_window( + mut every: Duration, + t: i64, + offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult, + period: Duration, + mut start: i64, + closed_window: ClosedWindow, + tz: Option<&Tz>, +) -> PolarsResult { + every.negative = !every.negative; + let mut stop = offset_fn(&period, start, tz)?; + while Bounds::new(start, stop).is_past(t, closed_window) { + start = offset_fn(&every, start, tz)?; + stop = offset_fn(&period, start, tz)?; + } + Ok(Bounds::new_checked(start, stop)) +} + /// Represents a window in time #[derive(Copy, Clone)] pub struct Window { @@ -82,24 +113,58 @@ impl Window { /// returns the bounds for the earliest window bounds /// that contains the given time t. For underlapping windows that /// do not contain time t, the window directly after time t will be returned. - pub fn get_earliest_bounds_ns(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn get_earliest_bounds_ns( + &self, + t: i64, + closed_window: ClosedWindow, + tz: Option<&Tz>, + ) -> PolarsResult { let start = self.truncate_ns(t, tz)?; - let stop = self.period.add_ns(start, tz)?; - - Ok(Bounds::new_checked(start, stop)) + ensure_t_in_or_in_front_of_window( + self.every, + t, + Duration::add_ns, + self.period, + start, + closed_window, + tz, + ) } - pub fn get_earliest_bounds_us(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn get_earliest_bounds_us( + &self, + t: i64, + closed_window: ClosedWindow, + tz: Option<&Tz>, + ) -> PolarsResult { let start = self.truncate_us(t, tz)?; - let stop = self.period.add_us(start, tz)?; - Ok(Bounds::new_checked(start, stop)) + ensure_t_in_or_in_front_of_window( + self.every, + t, + Duration::add_us, + self.period, + start, + closed_window, + tz, + ) } - pub fn get_earliest_bounds_ms(&self, t: i64, tz: Option<&Tz>) -> PolarsResult { + pub fn get_earliest_bounds_ms( + &self, + t: i64, + closed_window: ClosedWindow, + tz: Option<&Tz>, + ) -> PolarsResult { let start = self.truncate_ms(t, tz)?; - let stop = self.period.add_ms(start, tz)?; - - Ok(Bounds::new_checked(start, stop)) + ensure_t_in_or_in_front_of_window( + self.every, + t, + Duration::add_ms, + self.period, + start, + closed_window, + tz, + ) } pub(crate) fn estimate_overlapping_bounds_ns(&self, boundary: Bounds) -> usize { @@ -120,11 +185,12 @@ impl Window { pub fn get_overlapping_bounds_iter<'a>( &'a self, boundary: Bounds, + closed_window: ClosedWindow, tu: TimeUnit, tz: Option<&'a Tz>, start_by: StartBy, ) -> PolarsResult { - BoundsIter::new(*self, boundary, tu, tz, start_by) + BoundsIter::new(*self, closed_window, boundary, tu, tz, start_by) } } @@ -140,6 +206,7 @@ pub struct BoundsIter<'a> { impl<'a> BoundsIter<'a> { fn new( window: Window, + closed_window: ClosedWindow, boundary: Bounds, tu: TimeUnit, tz: Option<&'a Tz>, @@ -157,14 +224,20 @@ impl<'a> BoundsIter<'a> { boundary }, StartBy::WindowBound => match tu { - TimeUnit::Nanoseconds => window.get_earliest_bounds_ns(boundary.start, tz)?, - TimeUnit::Microseconds => window.get_earliest_bounds_us(boundary.start, tz)?, - TimeUnit::Milliseconds => window.get_earliest_bounds_ms(boundary.start, tz)?, + TimeUnit::Nanoseconds => { + window.get_earliest_bounds_ns(boundary.start, closed_window, tz)? + }, + TimeUnit::Microseconds => { + window.get_earliest_bounds_us(boundary.start, closed_window, tz)? + }, + TimeUnit::Milliseconds => { + window.get_earliest_bounds_ms(boundary.start, closed_window, tz)? 
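+                    // Note: all three time-unit arms now thread `closed_window`
+                    // through, so `ensure_t_in_or_in_front_of_window` can shift the
+                    // first window back until the earliest datapoint is covered.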
+ }, }, _ => { { #[allow(clippy::type_complexity)] - let (from, to, offset): ( + let (from, to, offset_fn): ( fn(i64) -> NaiveDateTime, fn(NaiveDateTime) -> i64, fn(&Duration, i64, Option<&Tz>) -> PolarsResult, @@ -186,9 +259,8 @@ impl<'a> BoundsIter<'a> { ), }; // find beginning of the week. - let mut boundary = boundary; let dt = from(boundary.start); - (boundary.start, boundary.stop) = match tz { + match tz { #[cfg(feature = "timezones")] Some(tz) => { let dt = tz.from_utc_datetime(&dt); @@ -196,16 +268,24 @@ impl<'a> BoundsIter<'a> { let dt = dt.naive_utc(); let start = to(dt); // adjust start of the week based on given day of the week - let start = offset( + let start = offset_fn( &Duration::parse(&format!("{}d", start_by.weekday().unwrap())), start, Some(tz), )?; // apply the 'offset' - let start = offset(&window.offset, start, Some(tz))?; + let start = offset_fn(&window.offset, start, Some(tz))?; + // make sure the first datapoint has a chance to be included // and compute the end of the window defined by the 'period' - let stop = offset(&window.period, start, Some(tz))?; - (start, stop) + ensure_t_in_or_in_front_of_window( + window.every, + boundary.start, + offset_fn, + window.period, + start, + closed_window, + Some(tz), + )? }, _ => { let tz = chrono::Utc; @@ -214,20 +294,27 @@ impl<'a> BoundsIter<'a> { let dt = dt.naive_utc(); let start = to(dt); // adjust start of the week based on given day of the week - let start = offset( + let start = offset_fn( &Duration::parse(&format!("{}d", start_by.weekday().unwrap())), start, None, ) .unwrap(); // apply the 'offset' - let start = offset(&window.offset, start, None).unwrap(); + let start = offset_fn(&window.offset, start, None).unwrap(); + // make sure the first datapoint has a chance to be included // and compute the end of the window defined by the 'period' - let stop = offset(&window.period, start, None).unwrap(); - (start, stop) + ensure_t_in_or_in_front_of_window( + window.every, + boundary.start, + offset_fn, + window.period, + start, + closed_window, + None, + )? }, - }; - boundary + } } }, }; diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 16f21574cd63..2d0e9cde550e 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -5554,8 +5554,8 @@ def group_by_dynamic( - [start + 2*every, start + 2*every + period) - ... - where `start` is determined by `start_by`, `offset`, and `every` (see parameter - descriptions below). + where `start` is determined by `start_by`, `offset`, `every`, and the earliest + datapoint. See the `start_by` argument description for details. .. warning:: The index column must be sorted in ascending order. If `by` is passed, then @@ -5577,7 +5577,7 @@ def group_by_dynamic( period length of the window, if None it will equal 'every' offset - offset of the window, only takes effect if `start_by` is `'window'`. + offset of the window, does not take effect if `start_by` is 'datapoint'. Defaults to negative `every`. truncate truncate the time value to the window lower bound @@ -5613,6 +5613,9 @@ def group_by_dynamic( * 'tuesday': Start the window on the Tuesday before the first data point. * ... * 'sunday': Start the window on the Sunday before the first data point. + + The resulting window is then shifted back until the earliest datapoint + is in or in front of it. check_sorted Check whether `index_column` is sorted (or, if `group_by` is given, check whether it's sorted within each group). 
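Concretely, the dynamic windows described in this docstring can be enumerated with the following hedged Python sketch (the pre-aligned `start` and the shift-back loop are simplifications of the Rust logic earlier in this patch, not polars API):

from datetime import datetime, timedelta

t0 = datetime(2021, 12, 16, 0, 0)        # earliest datapoint
every = period = timedelta(minutes=30)
offset = -every                          # the documented default

start = t0 + offset                      # t0 is already aligned to `every` here
while start > t0:                        # shift back so t0 is never skipped
    start -= every

for _ in range(3):                       # first three windows
    print(f"[{start}, {start + period})")
    start += every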
@@ -10694,6 +10697,9 @@ def groupby_dynamic( * 'tuesday': Start the window on the Tuesday before the first data point. * ... * 'sunday': Start the window on the Sunday before the first data point. + + The resulting window is then shifted back until the earliest datapoint + is in or in front of it. check_sorted Check whether `index_column` is sorted (or, if `by` is given, check whether it's sorted within each group). diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 05d5124bf11e..4ec0a5632e46 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -3403,8 +3403,8 @@ def group_by_dynamic( - [start + 2*every, start + 2*every + period) - ... - where `start` is determined by `start_by`, `offset`, and `every` (see parameter - descriptions below). + where `start` is determined by `start_by`, `offset`, `every`, and the earliest + datapoint. See the `start_by` argument description for details. .. warning:: The index column must be sorted in ascending order. If `by` is passed, then @@ -3426,7 +3426,7 @@ def group_by_dynamic( period length of the window, if None it will equal 'every' offset - offset of the window, only takes effect if `start_by` is `'window'`. + offset of the window, does not take effect if `start_by` is 'datapoint'. Defaults to negative `every`. truncate truncate the time value to the window lower bound @@ -3462,6 +3462,9 @@ def group_by_dynamic( * 'tuesday': Start the window on the Tuesday before the first data point. * ... * 'sunday': Start the window on the Sunday before the first data point. + + The resulting window is then shifted back until the earliest datapoint + is in or in front of it. check_sorted Check whether `index_column` is sorted (or, if `group_by` is given, check whether it's sorted within each group). @@ -6447,7 +6450,7 @@ def groupby_dynamic( period length of the window, if None it will equal 'every' offset - offset of the window, only takes effect if `start_by` is `'window'`. + offset of the window, does not take effect if `start_by` is 'datapoint'. Defaults to negative `every`. truncate truncate the time value to the window lower bound @@ -6472,6 +6475,9 @@ def groupby_dynamic( * 'tuesday': Start the window on the Tuesday before the first data point. * ... * 'sunday': Start the window on the Sunday before the first data point. + + The resulting window is then shifted back until the earliest datapoint + is in or in front of it. check_sorted Check whether `index_column` is sorted (or, if `by` is given, check whether it's sorted within each group). 
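The user-visible effect is mirrored by the regression test added below; as a hedged, tz-naive sketch of the same data (not part of the patch):

import polars as pl
from datetime import datetime, timedelta

df = pl.DataFrame(
    {
        "t": [datetime(2024, 3, 22, h) for h in (3, 4, 5, 6)],
        "v": [1, 10, 100, 1000],
    }
).set_sorted("t")

# Before this fix, the 03:00 and 04:00 rows were silently dropped; now they
# fall in the window starting 2024-03-21 05:00.
print(df.group_by_dynamic("t", every="1d", offset=timedelta(hours=5)).agg("v"))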
diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py index 6ed1a3d6ab53..12b26a4b7ded 100644 --- a/py-polars/tests/unit/operations/test_group_by_dynamic.py +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, timezone from typing import TYPE_CHECKING, Any import numpy as np @@ -990,3 +990,36 @@ def test_group_by_dynamic_check_sorted_15225() -> None: assert_frame_equal(result, expected) with pytest.raises(pl.InvalidOperationError, match="not explicitly sorted"): result = df.group_by_dynamic("b", every="2d").agg(pl.sum("a")) + + +@pytest.mark.parametrize("start_by", ["window", "friday"]) +def test_earliest_point_included_when_offset_is_set_15241(start_by: StartBy) -> None: + df = pl.DataFrame( + data={ + "t": pl.Series( + [ + datetime(2024, 3, 22, 3, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 4, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 5, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 6, 0, tzinfo=timezone.utc), + ] + ), + "v": [1, 10, 100, 1000], + } + ).set_sorted("t") + result = df.group_by_dynamic( + index_column="t", + every="1d", + offset=timedelta(hours=5), + start_by=start_by, + ).agg("v") + expected = pl.DataFrame( + { + "t": [ + datetime(2024, 3, 21, 5, 0, tzinfo=timezone.utc), + datetime(2024, 3, 22, 5, 0, tzinfo=timezone.utc), + ], + "v": [[1, 10], [100, 1000]], + } + ) + assert_frame_equal(result, expected) From 3888b8665d150402b48543374eef71615a77e5d1 Mon Sep 17 00:00:00 2001 From: Marshall Date: Thu, 28 Mar 2024 11:08:17 -0400 Subject: [PATCH 02/30] fix: Return correct dtype for `s.clear()` when dtype is `Object` (#15315) --- crates/polars-core/src/series/mod.rs | 20 +++++++++---------- py-polars/tests/unit/operations/test_clear.py | 1 - 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index b9fbcb179177..b72a83b4a3c9 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -157,17 +157,17 @@ impl Series { } pub fn clear(&self) -> Series { - // Only the inner of objects know their type, so use this hack. 
- #[cfg(feature = "object")] - if matches!(self.dtype(), DataType::Object(_, _)) { - return if self.is_empty() { - self.clone() - } else { - let av = self.get(0).unwrap(); - Series::new(self.name(), [av]).slice(0, 0) - }; + if self.is_empty() { + self.clone() + } else { + match self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_, _) => self + .take(&ChunkedArray::::new_vec("", vec![])) + .unwrap(), + dt => Series::new_empty(self.name(), dt), + } } - Series::new_empty(self.name(), self.dtype()) } #[doc(hidden)] diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index 76aba9b6e387..0ac3c1d27ba0 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -65,7 +65,6 @@ def test_clear_lf() -> None: assert ldfe.collect().rows() == [(None, None, None), (None, None, None)] -@pytest.mark.skip("Currently bugged: https://github.com/pola-rs/polars/issues/15303") def test_clear_series_object_starting_with_null() -> None: s = pl.Series([None, object()]) From 7c47d002b9ef1259edf21d035c6cc679a7f907d3 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 28 Mar 2024 16:09:15 +0100 Subject: [PATCH 03/30] build(python): Update Cargo lock (#15370) --- Cargo.lock | 328 ++++++++++++++------------ crates/polars-time/src/month_start.rs | 4 +- 2 files changed, 174 insertions(+), 158 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4451eec39152..0f1846db1aa9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -96,9 +96,9 @@ checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anyhow" -version = "1.0.80" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" [[package]] name = "apache-avro" @@ -142,9 +142,9 @@ checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" [[package]] name = "arrow-array" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609" +checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" dependencies = [ "ahash", "arrow-buffer", @@ -158,9 +158,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4" +checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" dependencies = [ "bytes", "half", @@ -169,9 +169,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e" +checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -181,9 +181,9 @@ dependencies = [ [[package]] name = 
"arrow-schema" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029" +checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" [[package]] name = "arrow2" @@ -224,18 +224,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -255,9 +255,9 @@ checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "avro-schema" @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b96342ea8948ab9bef3e6234ea97fc32e2d8a88d8fb6a084e52267317f94b6b" +checksum = "297b64446175a73987cedc3c438d79b2a654d0fff96f65ff530fbe039347644c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -302,14 +302,15 @@ dependencies = [ "time", "tokio", "tracing", + "url", "zeroize", ] [[package]] name = "aws-credential-types" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273fa47dafc9ef14c2c074ddddbea4561ff01b7f68d5091c0e9737ced605c01d" +checksum = "fa8587ae17c8e967e4b05a62d495be2fb7701bec52a97f7acfe8a29f938384c8" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -319,9 +320,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e38bab716c8bf07da24be07ecc02e0f5656ce8f30a891322ecdcb202f943b85" +checksum = "b13dc54b4b49f8288532334bba8f87386a40571c47c37b1304979b556dc613c8" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -343,10 +344,11 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.17.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d35d39379445970fc3e4ddf7559fff2c32935ce0b279f9cb27080d6b7c6d94" +checksum = "bc075ffee2a40cb1590bed35d7ec953589a564e768fa91947c565425cd569269" dependencies = [ + "ahash", "aws-credential-types", "aws-runtime", "aws-sigv4", @@ -361,20 +363,25 @@ dependencies = [ "aws-smithy-xml", "aws-types", "bytes", + "fastrand", + "hex", + "hmac", "http 0.2.12", "http-body", + "lru", "once_cell", "percent-encoding", "regex-lite", + "sha2", "tracing", "url", ] [[package]] name = "aws-sdk-sso" -version = "1.15.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d84bd3925a17c9adbf6ec65d52104a44a09629d8f70290542beeee69a95aee7f" +checksum = "019a07902c43b03167ea5df0182f0cb63fae89f9a9682c44d18cf2e4a042cb34" dependencies = [ "aws-credential-types", "aws-runtime", @@ -394,9 +401,9 @@ dependencies = [ [[package]] 
name = "aws-sdk-ssooidc" -version = "1.15.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c2dae39e997f58bc4d6292e6244b26ba630c01ab671b6f9f44309de3eb80ab8" +checksum = "04c46ee08a48a7f4eaa4ad201dcc1dd537b49c50859d14d4510e00ad9d3f9af2" dependencies = [ "aws-credential-types", "aws-runtime", @@ -416,9 +423,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.15.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17fd9a53869fee17cea77e352084e1aa71e2c5e323d974c13a9c2bcfd9544c7f" +checksum = "f752ac730125ca6017f72f9db5ec1772c9ecc664f87aa7507a7d81b023c23713" dependencies = [ "aws-credential-types", "aws-runtime", @@ -439,9 +446,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.1.7" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ada00a4645d7d89f296fe0ddbc3fe3554f03035937c849a05d37ddffc1f29a1" +checksum = "11d6f29688a4be9895c0ba8bef861ad0c0dac5c15e9618b9b7a6c233990fc263" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -468,9 +475,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.7" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcf7f09a27286d84315dfb9346208abb3b0973a692454ae6d0bc8d803fcce3b4" +checksum = "f7a41ccd6b74401a49ca828617049e5c23d83163d330a4f90a8081aadee0ac45" dependencies = [ "futures-util", "pin-project-lite", @@ -479,9 +486,9 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.60.6" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd4b66f2a8e7c84d7e97bda2666273d41d2a2e25302605bcf906b7b2661ae5e" +checksum = "83fa43bc04a6b2441968faeab56e68da3812f978a670a5db32accbdcafddd12f" dependencies = [ "aws-smithy-http", "aws-smithy-types", @@ -511,9 +518,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.6" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6ca214a6a26f1b7ebd63aa8d4f5e2194095643023f9608edf99a58247b9d80d" +checksum = "3f10fa66956f01540051b0aa7ad54574640f748f9839e843442d99b970d3aff9" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -532,18 +539,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.6" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1af80ecf3057fb25fe38d1687e94c4601a7817c6a1e87c1b0635f7ecb644ace5" +checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.6" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb27084f72ea5fc20033efe180618677ff4a2f474b53d84695cfe310a6526cbc" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" dependencies = [ "aws-smithy-types", "urlencoding", @@ -551,9 +558,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbb5fca54a532a36ff927fbd7407a7c8eb9c3b4faf72792ba2965ea2cad8ed55" +checksum = "ec81002d883e5a7fd2bb063d6fb51c4999eb55d404f4fff3dd878bf4733b9f01" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -576,9 +583,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.1.7" 
+version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22389cb6f7cac64f266fb9f137745a9349ced7b47e0d2ba503e9e40ede4f7060" +checksum = "9acb931e0adaf5132de878f1398d83f8677f90ba70f01f65ff87f6d7244be1c5" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -593,9 +600,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f081da5481210523d44ffd83d9f0740320050054006c719eae0232d411f024d3" +checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729" dependencies = [ "base64-simd", "bytes", @@ -616,18 +623,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.6" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fccd8f595d0ca839f9f2548e66b99514a85f92feb4c01cf2868d93eb4888a42" +checksum = "872c68cf019c0e4afc5de7753c4f7288ce4b71663212771bf5e4542eb9346ca9" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.7" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d07c63521aa1ea9a9f92a701f1a08ce3fd20b46c6efc0d5c8947c1fd879e3df1" +checksum = "0dbf2f3da841a8930f159163175cf6a3d16ddde517c1b0fba7aa776822800f40" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -640,9 +647,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -698,9 +705,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" dependencies = [ "serde", ] @@ -716,9 +723,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -754,29 +761,29 @@ checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "bytes-utils" @@ -833,9 +840,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.35" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" +checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" dependencies = [ "android-tzdata", "iana-time-zone", @@ -895,9 +902,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.2" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b230ab84b0ffdf890d5a10abdbc8b83ae1c4918275daea1ab8801f71536b2651" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", ] @@ -1117,7 +1124,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "crossterm_winapi", "libc", "parking_lot", @@ -1275,7 +1282,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1324,9 +1331,9 @@ checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" [[package]] name = "fastrand" -version = "2.0.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "ff" @@ -1435,7 +1442,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1499,11 +1506,11 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "git2" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" +checksum = "232e6a7bfe35766bf715e55a88b39a700596c0ccfd88cd3680b4cdb40d66ef70" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "libc", "libgit2-sys", "log", @@ -1529,9 +1536,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb" dependencies = [ "bytes", "fnv", @@ -1755,9 +1762,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1766,9 +1773,9 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = 
"inventory" @@ -1813,9 +1820,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "itoap" @@ -2042,9 +2049,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.15" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037731f5d3aaa87a5675e895b63ddff1a87624bc29f77004ea829809654e48f6" +checksum = "5e143b5e666b2695d28f6bca6497720813f699c9602dd7f5cac91008b8ada7f9" dependencies = [ "cc", "libc", @@ -2074,6 +2081,15 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "lru" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +dependencies = [ + "hashbrown 0.14.3", +] + [[package]] name = "lz4" version = "1.24.0" @@ -2116,9 +2132,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "memmap2" @@ -2131,9 +2147,9 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" dependencies = [ "autocfg", ] @@ -2175,9 +2191,9 @@ dependencies = [ [[package]] name = "multiversion" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" dependencies = [ "multiversion-macros", "target-features", @@ -2185,9 +2201,9 @@ dependencies = [ [[package]] name = "multiversion-macros" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" dependencies = [ "proc-macro2", "quote", @@ -2629,7 +2645,7 @@ dependencies = [ "proptest", "rand", "regex", - "regex-syntax 0.8.2", + "regex-syntax 0.8.3", "ryu", "sample-arrow2", "sample-std 0.1.1", @@ -2678,7 +2694,7 @@ dependencies = [ "ahash", "arrow-array", "bincode", - "bitflags 2.4.2", + "bitflags 2.5.0", "bytemuck", "chrono", "chrono-tz", @@ -2807,7 +2823,7 @@ name = "polars-lazy" version = "0.38.3" dependencies = [ "ahash", - "bitflags 2.4.2", + "bitflags 2.5.0", "futures", "glob", "once_cell", @@ -3031,9 +3047,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" 
dependencies = [ "unicode-ident", ] @@ -3044,13 +3060,13 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "lazy_static", "num-traits", "rand", "rand_chacha", "rand_xorshift", - "regex-syntax 0.8.2", + "regex-syntax 0.8.3", "unarray", ] @@ -3171,7 +3187,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3184,7 +3200,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -3288,7 +3304,7 @@ version = "11.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", ] [[package]] @@ -3299,9 +3315,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -3343,19 +3359,19 @@ checksum = "5fddb4f8d99b0a2ebafc65a87a69a7b9875e4b1ae1f00db265d300ef7f28bccc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", "regex-automata", - "regex-syntax 0.8.2", + "regex-syntax 0.8.3", ] [[package]] @@ -3366,7 +3382,7 @@ checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.2", + "regex-syntax 0.8.3", ] [[package]] @@ -3383,15 +3399,15 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "reqwest" -version = "0.11.24" +version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64", "bytes", @@ -3485,11 +3501,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -3541,9 +3557,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.3.1" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" +checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" [[package]] name = "rustls-webpki" @@ -3729,14 +3745,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" dependencies = [ "indexmap", "itoa", @@ -3808,9 +3824,9 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.13.8" +version = "0.13.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2faf8f101b9bc484337a6a6b0409cf76c139f2fb70a9e3aee6b6774be7bfbf76" +checksum = "b0b84c23a1066e1d650ebc99aa8fb9f8ed0ab96fd36e2e836173c92fc9fb29bc" dependencies = [ "ahash", "getrandom", @@ -3847,9 +3863,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "smartstring" @@ -3982,7 +3998,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4004,9 +4020,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" dependencies = [ "proc-macro2", "quote", @@ -4056,9 +4072,9 @@ dependencies = [ [[package]] name = "target-features" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" @@ -4080,22 +4096,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4188,7 +4204,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4275,7 +4291,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4310,7 +4326,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ 
"proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4348,9 +4364,9 @@ dependencies = [ [[package]] name = "unicode-reverse" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bea5dacebb0d2d0a69a6700a05b59b3908bf801bf563a49bd27a1b60122962c" +checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" dependencies = [ "unicode-segmentation", ] @@ -4398,9 +4414,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", "serde", @@ -4482,7 +4498,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-shared", ] @@ -4516,7 +4532,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4780,7 +4796,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -4791,27 +4807,27 @@ checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zstd" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.0.0" +version = "7.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" +version = "2.0.10+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" dependencies = [ "cc", "pkg-config", diff --git a/crates/polars-time/src/month_start.rs b/crates/polars-time/src/month_start.rs index fbe991847cb8..76a8cf8be942 100644 --- a/crates/polars-time/src/month_start.rs +++ b/crates/polars-time/src/month_start.rs @@ -30,7 +30,7 @@ pub(crate) fn roll_backward( ts.hour(), ts.minute(), ts.second(), - ts.timestamp_subsec_nanos(), + ts.and_utc().timestamp_subsec_nanos(), ) .ok_or_else(|| { polars_err!( @@ -40,7 +40,7 @@ pub(crate) fn roll_backward( ts.hour(), ts.minute(), ts.second(), - ts.timestamp_subsec_nanos() + ts.and_utc().timestamp_subsec_nanos() ) ) })?; From d999c010028e7ced396642503027a602980d63d6 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 28 Mar 2024 16:14:51 +0100 Subject: [PATCH 04/30] python polars 0.20.17 (#15372) --- Cargo.lock | 2 +- py-polars/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f1846db1aa9..eb37fee66fab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3104,7 +3104,7 @@ dependencies = [ [[package]] name = "py-polars" 
-version = "0.20.16" +version = "0.20.17" dependencies = [ "ahash", "built", diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 9ab764651c98..a48000225eab 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.20.16" +version = "0.20.17" edition = "2021" [lib] From bf3f200e672e2a46d12a2e5f94d404e35bec2b7a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:17:56 +0100 Subject: [PATCH 05/30] chore(rust): bump sample-test from 0.1.1 to 0.2.1 (#15369) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ crates/polars-arrow/Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eb37fee66fab..d0692b3eef05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3630,24 +3630,24 @@ dependencies = [ [[package]] name = "sample-test" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713e500947ff19fc1ae2805afa33ef45f3bb2ec656c77d92252d24cf9e3091b2" +checksum = "e8b253ca516416756b09b582e2b7275de8f51f35e5d5711e20712b9377c7d5bf" dependencies = [ "quickcheck", - "sample-std 0.1.1", + "sample-std 0.2.1", "sample-test-macros", ] [[package]] name = "sample-test-macros" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1a2c832a259aae95b6ed1da3aa377111ffde38d4282fa734faa3fff356534e" +checksum = "5cc6439a7589bb4581fdadb6391700ce4d26f8bffd34e2a75acb320822e9b5ef" dependencies = [ "proc-macro2", "quote", - "sample-std 0.1.1", + "sample-std 0.2.1", "syn 1.0.109", ] diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 2be06f69a744..57a06966fd57 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -87,7 +87,7 @@ rand = { workspace = true } # use for generating and testing random data samples sample-arrow2 = "0.17" sample-std = "0.1" -sample-test = "0.1" +sample-test = "0.2" # used to test async readers tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } tokio-util = { workspace = true, features = ["compat"] } From a8713597245ade5c680044b4428a89234195ce46 Mon Sep 17 00:00:00 2001 From: Rob <124158982+rob-sil@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:33:43 -0700 Subject: [PATCH 06/30] docs(python): Add "See Also" for `arg_sort` and `arg_sort_by` (#15348) --- py-polars/polars/expr/expr.py | 20 ++++++++++++++++++++ py-polars/polars/functions/lazy.py | 23 ++++++++++++++++++++++- py-polars/polars/series/series.py | 5 +++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f34e3eb1f2ff..39ddacbb99d1 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2131,11 +2131,17 @@ def arg_sort(self, *, descending: bool = False, nulls_last: bool = False) -> Sel Expr Expression of data type :class:`UInt32`. + See Also + -------- + Expr.gather: Take values by index. + Expr.rank : Get the rank of each row. + Examples -------- >>> df = pl.DataFrame( ... { ... "a": [20, 10, 30], + ... "b": [1, 2, 3], ... } ... 
) >>> df.select(pl.col("a").arg_sort()) @@ -2149,6 +2155,20 @@ def arg_sort(self, *, descending: bool = False, nulls_last: bool = False) -> Sel │ 0 │ │ 2 │ └─────┘ + + Use gather to apply the arg sort to other columns. + + >>> df.select(pl.col("b").gather(pl.col("a").arg_sort())) + shape: (3, 1) + ┌─────┐ + │ b │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 2 │ + │ 1 │ + │ 3 │ + └─────┘ """ return self._from_pyexpr(self._pyexpr.arg_sort(descending, nulls_last)) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 20cb0c62d080..ecb884f942b1 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1539,7 +1539,7 @@ def arg_sort_by( descending: bool | Sequence[bool] = False, ) -> Expr: """ - Return the row indices that would sort the columns. + Return the row indices that would sort the column(s). Parameters ---------- @@ -1552,6 +1552,11 @@ def arg_sort_by( Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. + See Also + -------- + Expr.gather: Take values by index. + Expr.rank : Get the rank of each row. + Examples -------- Pass a single column name to compute the arg sort by that column. @@ -1560,6 +1565,7 @@ def arg_sort_by( ... { ... "a": [0, 1, 1, 0], ... "b": [3, 2, 3, 2], + ... "c": [1, 2, 3, 4], ... } ... ) >>> df.select(pl.arg_sort_by("a")) @@ -1590,6 +1596,21 @@ def arg_sort_by( │ 0 │ │ 3 │ └─────┘ + + Use gather to apply the arg sort to other columns. + + >>> df.select(pl.col("c").gather(pl.arg_sort_by("a"))) + shape: (4, 1) + ┌─────┐ + │ c │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 1 │ + │ 4 │ + │ 2 │ + │ 3 │ + └─────┘ """ exprs = parse_as_list_of_expressions(exprs, *more_exprs) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 1dbfab9f880e..3bd8026651db 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3459,6 +3459,11 @@ def arg_sort(self, *, descending: bool = False, nulls_last: bool = False) -> Ser nulls_last Place null values last instead of first. + See Also + -------- + Series.gather: Take values by index. + Series.rank : Get the rank of each row. + Examples -------- >>> s = pl.Series("a", [5, 3, 4, 1, 2]) From bd1882cb3c4b13a55cd810be5281ff3d31e36d74 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 29 Mar 2024 15:11:03 +0800 Subject: [PATCH 07/30] docs(python): Change the example to series for `series/array.py` (#15383) --- py-polars/polars/series/array.py | 50 ++++++++++++++------------------ 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 793ae0507404..a7044c66588b 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -61,20 +61,14 @@ def sum(self) -> Series: Examples -------- - >>> df = pl.DataFrame( - ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(pl.Int64, 2)}, - ... ) - >>> df.select(pl.col("a").arr.sum()) - shape: (2, 1) - ┌─────┐ - │ a │ - │ --- │ - │ i64 │ - ╞═════╡ - │ 3 │ - │ 7 │ - └─────┘ + >>> s = pl.Series([[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s.arr.sum() + shape: (2,) + Series: '' [i64] + [ + 3 + 7 + ] """ def std(self, ddof: int = 1) -> Series: @@ -134,23 +128,21 @@ def unique(self, *, maintain_order: bool = False) -> Series: maintain_order Maintain order of data. This requires more work. + Returns + ------- + Series + Series of data type :class:`List`. 
+
        Examples
        --------
-        >>> df = pl.DataFrame(
-        ...     {
-        ...         "a": [[1, 1, 2]],
-        ...     },
-        ...     schema_overrides={"a": pl.Array(pl.Int64, 3)},
-        ... )
-        >>> df.select(pl.col("a").arr.unique())
-        shape: (1, 1)
-        ┌───────────┐
-        │ a         │
-        │ ---       │
-        │ list[i64] │
-        ╞═══════════╡
-        │ [1, 2]    │
-        └───────────┘
+        >>> s = pl.Series([[1, 1, 2], [3, 4, 5]], dtype=pl.Array(pl.Int64, 3))
+        >>> s.arr.unique()
+        shape: (2,)
+        Series: '' [list[i64]]
+        [
+                [1, 2]
+                [3, 4, 5]
+        ]
        """

    def n_unique(self) -> Series:

From eba3e7678bf3da42d2ccd012c59031a444d9da29 Mon Sep 17 00:00:00 2001
From: Alexander Beedie
Date: Fri, 29 Mar 2024 11:47:27 +0400
Subject: [PATCH 08/30] chore: Update CODEOWNERS (polars-sql) (#15384)

---
 .github/CODEOWNERS | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 72eedaf196c8..b5527bf62325 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,6 +1,6 @@
 * @ritchie46 @stinodego @c-peters
 /crates/ @ritchie46 @stinodego @orlp @orlp @c-peters
-/crates/polars-sql/ @ritchie46 @stinodego @orlp @c-peters @universalmind303
+/crates/polars-sql/ @ritchie46 @stinodego @orlp @c-peters @universalmind303 @alexander-beedie
 /crates/polars-time/ @ritchie46 @stinodego @orlp @c-peters @MarcoGorelli
 /py-polars/ @ritchie46 @stinodego @c-peters @alexander-beedie @MarcoGorelli @reswqa

From f61594a7b55ddf351afe81cc8119c89cf46df60e Mon Sep 17 00:00:00 2001
From: Weijie Guo
Date: Fri, 29 Mar 2024 15:49:05 +0800
Subject: [PATCH 09/30] fix: `sort` for series with unsupported dtype should raise instead of panic (#15385)

---
 .../src/chunked_array/array/iterator.rs       | 23 +++++++++++++++++++
 crates/polars-core/src/frame/mod.rs           |  2 +-
 .../src/series/implementations/binary.rs      |  4 ++--
 .../series/implementations/binary_offset.rs   |  4 ++--
 .../src/series/implementations/boolean.rs     |  4 ++--
 .../src/series/implementations/categorical.rs |  4 ++--
 .../src/series/implementations/dates_time.rs  |  4 ++--
 .../src/series/implementations/datetime.rs    |  7 +++---
 .../src/series/implementations/decimal.rs     |  7 +++---
 .../src/series/implementations/duration.rs    |  7 +++---
 .../src/series/implementations/floats.rs      |  4 ++--
 .../src/series/implementations/mod.rs         |  4 ++--
 .../src/series/implementations/string.rs      |  4 ++--
 .../src/series/implementations/struct_.rs     | 22 ++++++++----------
 crates/polars-core/src/series/mod.rs          |  2 +-
 crates/polars-core/src/series/series_trait.rs |  4 ++--
 .../src/physical_plan/expressions/sort.rs     |  4 ++--
 .../src/chunked_array/array/namespace.rs      |  4 ++--
 .../src/chunked_array/list/namespace.rs       |  7 +++---
 crates/polars-ops/src/series/ops/cut.rs       |  2 +-
 .../src/executors/sinks/sort/sink.rs          |  2 +-
 .../src/dsl/function_expr/array.rs            |  2 +-
 .../polars-plan/src/dsl/function_expr/list.rs |  2 +-
 crates/polars-sql/tests/iss_8395.rs           |  2 +-
 py-polars/src/series/mod.rs                   |  8 +++++--
 25 files changed, 83 insertions(+), 56 deletions(-)

diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs
index 589b8996dc62..cbb38f954e5e 100644
--- a/crates/polars-core/src/chunked_array/array/iterator.rs
+++ b/crates/polars-core/src/chunked_array/array/iterator.rs
@@ -124,6 +124,29 @@ impl ArrayChunked {
         .collect_ca_with_dtype(self.name(), self.dtype().clone())
     }

+    /// Try to apply a closure `F` to each array.
+    ///
+    /// # Safety
+    /// The `Series` returned by `F` must have the same dtype and number of elements as the input if it is `Ok`.
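+    ///
+    /// A minimal usage sketch (hypothetical caller; assumes `ca: &ArrayChunked`
+    /// and a `SortOptions` value `options` in scope):
+    ///
+    /// ```ignore
+    /// unsafe { ca.try_apply_amortized_same_type(|s| s.as_ref().sort_with(options)) }
+    /// ```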
+ pub unsafe fn try_apply_amortized_same_type<'a, F>(&'a self, mut f: F) -> PolarsResult + where + F: FnMut(UnstableSeries<'a>) -> PolarsResult, + { + if self.is_empty() { + return Ok(self.clone()); + } + self.amortized_iter() + .map(|opt_v| { + opt_v + .map(|v| { + let out = f(v)?; + Ok(to_arr(&out)) + }) + .transpose() + }) + .try_collect_ca_with_dtype(self.name(), self.dtype().clone()) + } + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. /// /// # Safety diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 019fb6c342d8..0d8a696d190a 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1843,7 +1843,7 @@ impl DataFrame { // no need to compute the sort indices and then take by these indices // simply sort and return as frame if df.width() == 1 && df.check_name_to_idx(s.name()).is_ok() { - let mut out = s.sort_with(options); + let mut out = s.sort_with(options)?; if let Some((offset, len)) = slice { out = out.slice(offset, len); } diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 86705e8f9af3..b0da78d5141a 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -173,8 +173,8 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index d0a5523c7d8c..8cea50d76212 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -136,8 +136,8 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 1aa17d298d3f..06f9f0920243 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -198,8 +198,8 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index f9a23c261417..1efd20432d94 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -226,8 +226,8 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> 
Series { - self.0.sort_with(options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index 5f5e993dcbbc..c0516e92ce28 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -266,8 +266,8 @@ macro_rules! impl_dyn_series { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - self.0.sort_with(options).$into_logical().into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self.0.sort_with(options).$into_logical().into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index c4c8bfe1b47b..29accf0ac33e 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -266,11 +266,12 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - self.0 + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 .sort_with(options) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() + .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index d50befa1059c..1c3e6f61566c 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -261,11 +261,12 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - self.0 + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 .sort_with(options) .into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() + .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 30e3f30857e0..0dfecd55de9c 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -326,11 +326,12 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - self.0 + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(self + .0 .sort_with(options) .into_duration(self.0.time_unit()) - .into_series() + .into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 3332649da16b..3f566463acfd 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -229,8 +229,8 @@ macro_rules! 
impl_dyn_series { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index ab51d751436a..6f589f1b61a1 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -332,8 +332,8 @@ macro_rules! impl_dyn_series { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 9a8c1b1f6aa4..38a60a1be615 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -180,8 +180,8 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } - fn sort_with(&self, options: SortOptions) -> Series { - ChunkSort::sort_with(&self.0, options).into_series() + fn sort_with(&self, options: SortOptions) -> PolarsResult { + Ok(ChunkSort::sort_with(&self.0, options).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index fcf85754aac7..219b88f36f7d 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -313,7 +313,7 @@ impl SeriesTrait for SeriesWrap { &self.0 } - fn sort_with(&self, options: SortOptions) -> Series { + fn sort_with(&self, options: SortOptions) -> PolarsResult { let df = self.0.clone().unnest(); let desc = if options.descending { @@ -321,17 +321,15 @@ impl SeriesTrait for SeriesWrap { } else { vec![false; df.width()] }; - let out = df - .sort_impl( - df.columns.clone(), - desc, - options.nulls_last, - options.maintain_order, - None, - options.multithreaded, - ) - .unwrap(); - StructChunked::new_unchecked(self.name(), &out.columns).into_series() + let out = df.sort_impl( + df.columns.clone(), + desc, + options.nulls_last, + options.maintain_order, + None, + options.multithreaded, + )?; + Ok(StructChunked::new_unchecked(self.name(), &out.columns).into_series()) } fn arg_sort(&self, options: SortOptions) -> IdxCa { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index b72a83b4a3c9..2eb60f08788b 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -285,7 +285,7 @@ impl Series { Ok(self) } - pub fn sort(&self, descending: bool, nulls_last: bool) -> Self { + pub fn sort(&self, descending: bool, nulls_last: bool) -> PolarsResult { self.sort_with(SortOptions { descending, nulls_last, diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index b97ffc864257..ff4a7eca65a6 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -333,8 +333,8 @@ pub trait 
SeriesTrait: invalid_operation_panic!(get_unchecked, self) } - fn sort_with(&self, _options: SortOptions) -> Series { - invalid_operation_panic!(sort_with, self) + fn sort_with(&self, _options: SortOptions) -> PolarsResult { + polars_bail!(opq = sort_with, self._dtype()); } /// Retrieve the indexes needed for a sort. diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 0df7d4b94ab9..207ad3c82915 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -51,7 +51,7 @@ impl PhysicalExpr for SortExpr { } fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let series = self.physical_expr.evaluate(df, state)?; - Ok(series.sort_with(self.options)) + series.sort_with(self.options) } #[allow(clippy::ptr_arg)] @@ -65,7 +65,7 @@ impl PhysicalExpr for SortExpr { match ac.agg_state() { AggState::AggregatedList(s) => { let ca = s.list().unwrap(); - let out = ca.lst_sort(self.options); + let out = ca.lst_sort(self.options)?; ac.with_series(out.into_series(), true, Some(&self.expr))?; }, _ => { diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index 42e402f25066..42555ef52080 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -96,10 +96,10 @@ pub trait ArrayNameSpace: AsArray { array_all(ca) } - fn array_sort(&self, options: SortOptions) -> ArrayChunked { + fn array_sort(&self, options: SortOptions) -> PolarsResult { let ca = self.as_array(); // SAFETY: Sort only changes the order of the elements in each subarray. - unsafe { ca.apply_amortized_same_type(|s| s.as_ref().sort_with(options)) } + unsafe { ca.try_apply_amortized_same_type(|s| s.as_ref().sort_with(options)) } } fn array_reverse(&self) -> ArrayChunked { diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 38ca7732c40c..a4f7e78e2c6d 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -242,11 +242,10 @@ pub trait ListNameSpaceImpl: AsList { } } - #[must_use] - fn lst_sort(&self, options: SortOptions) -> ListChunked { + fn lst_sort(&self, options: SortOptions) -> PolarsResult { let ca = self.as_list(); - let out = ca.apply_amortized(|s| s.as_ref().sort_with(options)); - self.same_type(out) + let out = ca.try_apply_amortized(|s| s.as_ref().sort_with(options))?; + Ok(self.same_type(out)) } #[must_use] diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index b7e87a23d8a8..3b7d10dcb0a4 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -103,7 +103,7 @@ pub fn qcut( include_breaks: bool, ) -> PolarsResult { let s = s.cast(&DataType::Float64)?; - let s2 = s.sort(false, false); + let s2 = s.sort(false, false)?; let ca = s2.f64()?; if ca.null_count() == ca.len() { diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 1411a76c872d..01138240b6ce 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -188,7 +188,7 @@ impl Sink for SortSink { nulls_last: self.sort_args.nulls_last, multithreaded: true, maintain_order: 
self.sort_args.maintain_order, - }); + })?; let instant = self.ooc_start.unwrap(); if context.verbose { diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index a731a8e0c70a..9dc04bd477c7 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -186,7 +186,7 @@ pub(super) fn all(s: &Series) -> PolarsResult { } pub(super) fn sort(s: &Series, options: SortOptions) -> PolarsResult { - Ok(s.array()?.array_sort(options).into_series()) + Ok(s.array()?.array_sort(options)?.into_series()) } pub(super) fn reverse(s: &Series) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index b9dafcf9e305..3fdbf6a18134 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -551,7 +551,7 @@ pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsRes } pub(super) fn sort(s: &Series, options: SortOptions) -> PolarsResult { - Ok(s.list()?.lst_sort(options).into_series()) + Ok(s.list()?.lst_sort(options)?.into_series()) } pub(super) fn reverse(s: &Series) -> PolarsResult { diff --git a/crates/polars-sql/tests/iss_8395.rs b/crates/polars-sql/tests/iss_8395.rs index bc20f1e448f6..06d74affb33e 100644 --- a/crates/polars-sql/tests/iss_8395.rs +++ b/crates/polars-sql/tests/iss_8395.rs @@ -19,7 +19,7 @@ fn iss_8395() -> PolarsResult<()> { let df = res.collect()?; // assert that the df only contains [vegetables, seafood] - let s = df.column("category")?.unique()?.sort(false, false); + let s = df.column("category")?.unique()?.sort(false, false)?; let expected = Series::new("category", &["seafood", "vegetables"]); assert!(s.equals(&expected)); Ok(()) diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 9322eaa9d93a..b3175bc9e731 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -283,8 +283,12 @@ impl PySeries { } } - fn sort(&mut self, descending: bool, nulls_last: bool) -> Self { - self.series.sort(descending, nulls_last).into() + fn sort(&mut self, descending: bool, nulls_last: bool) -> PyResult { + Ok(self + .series + .sort(descending, nulls_last) + .map_err(PyPolarsErr::from)? 
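+            // `sort` is now fallible: the error is mapped through `PyPolarsErr`
+            // above so it reaches Python as an exception rather than a panic.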
+            .into())
    }

    fn take_with_series(&self, indices: &PySeries) -> PyResult<Self> {

From 31bb26f97c18514406bb70d81a84dc998560b25a Mon Sep 17 00:00:00 2001
From: Kevin Lim
Date: Fri, 29 Mar 2024 01:24:09 -0700
Subject: [PATCH 10/30] fix(python): fix panic when doing a scan_parquet with hive partitioning (#15381)

Co-authored-by: ritchie
---
 .../src/logical_plan/optimizer/cache_states.rs       | 3 ++-
 crates/polars-plan/src/logical_plan/optimizer/mod.rs | 9 ++++++++-
 py-polars/tests/unit/io/test_hive.py                 | 6 ++++++
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs
index 6f7a865daade..04d9fb3d773c 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs
@@ -114,6 +114,7 @@ pub(super) fn set_cache_states(
     lp_arena: &mut Arena<ALogicalPlan>,
     expr_arena: &mut Arena<AExpr>,
     scratch: &mut Vec<Node>,
+    hive_partition_eval: HiveEval<'_>,
     verbose: bool,
 ) -> PolarsResult<()> {
     let mut stack = Vec::with_capacity(4);
@@ -262,7 +263,7 @@ pub(super) fn set_cache_states(
         // back to the cache node again
         if !cache_schema_and_children.is_empty() {
             let mut proj_pd = ProjectionPushDown::new();
-            let pred_pd = PredicatePushDown::new(Default::default()).block_at_cache(false);
+            let pred_pd = PredicatePushDown::new(hive_partition_eval).block_at_cache(false);
             for (_cache_id, v) in cache_schema_and_children {
                 // # CHECK IF WE NEED TO REMOVE CACHES
                 // If we encounter multiple predicates we remove the cache nodes completely as we don't
diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs
index 1ba7d8befd85..3231a730edee 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs
@@ -187,7 +187,14 @@ pub fn optimize(
     lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?;

     if members.has_joins_or_unions && members.has_cache {
-        cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, verbose)?;
+        cache_states::set_cache_states(
+            lp_top,
+            lp_arena,
+            expr_arena,
+            scratch,
+            hive_partition_eval,
+            verbose,
+        )?;
     }

     // This one should run (nearly) last as this modifies the projections
diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py
index 67ddef655366..d8f4a81beddb 100644
--- a/py-polars/tests/unit/io/test_hive.py
+++ b/py-polars/tests/unit/io/test_hive.py
@@ -88,6 +88,12 @@ def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files(
     assert q.filter(pl.col("a").is_in([1, 4])).collect().shape == (2, 2)
     assert "hive partitioning: skipped 3 files" in capfd.readouterr().err

+    # Ensure the CSE can work with hive partitions.
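+    # With comm_subplan_elim=True the self-join below shares one cached scan;
+    # that cached path previously ran predicate pushdown with an empty hive
+    # partition evaluator and panicked on the pushed-down filter.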
+ q = q.filter(pl.col("a").gt(2)) + assert q.join(q, on="a", how="left").collect(comm_subplan_elim=True).to_dict( + as_series=False + ) == {"d": [3, 4], "a": [3, 4], "d_right": [3, 4]} + @pytest.mark.skip( reason="Broken by pyarrow 15 release: https://github.com/pola-rs/polars/issues/13892" From 4a6c179aa81450b5e3b1ee975d0aa78fe898ca1d Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 29 Mar 2024 10:12:55 +0100 Subject: [PATCH 11/30] fix: Conversion of expr_ir in partition fast path (#15388) --- crates/polars-lazy/src/physical_plan/planner/lp.rs | 4 ++-- .../src/physical_plan/streaming/convert_alp.rs | 10 ++++------ crates/polars-plan/src/logical_plan/expr_ir.rs | 4 ++-- py-polars/polars/testing/__init__.py | 1 + py-polars/polars/testing/_constants.py | 2 ++ .../tests/unit/streaming/test_streaming_group_by.py | 8 ++++++++ 6 files changed, 19 insertions(+), 10 deletions(-) create mode 100644 py-polars/polars/testing/_constants.py diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 68cfbc09ad2a..f2ccf79706d0 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -431,11 +431,11 @@ pub fn create_physical_plan( let input = create_physical_plan(input, lp_arena, expr_arena)?; let keys = keys .iter() - .map(|e| node_to_expr(e.node(), expr_arena)) + .map(|e| e.to_expr(expr_arena)) .collect::>(); let aggs = aggs .iter() - .map(|e| node_to_expr(e.node(), expr_arena)) + .map(|e| e.to_expr(expr_arena)) .collect::>(); Ok(Box::new(executors::PartitionGroupByExec::new( input, diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 8b2727dc3d45..acd293dead12 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -386,7 +386,7 @@ pub(crate) fn insert_streaming_nodes( aggs, maintain_order: false, apply: None, - schema, + schema: output_schema, options, .. 
} => { @@ -435,17 +435,15 @@ pub(crate) fn insert_streaming_nodes( let valid_key = || { keys.iter().all(|e| { - expr_arena - .get(e.node()) - .get_type(schema, Context::Default, expr_arena) - // ensure we don't group_by list + output_schema + .get(e.output_name()) .map(|dt| !matches!(dt, DataType::List(_))) .unwrap_or(false) }) }; let valid_types = || { - schema + output_schema .iter_dtypes() .all(|dt| allowed_dtype(dt, string_cache)) }; diff --git a/crates/polars-plan/src/logical_plan/expr_ir.rs b/crates/polars-plan/src/logical_plan/expr_ir.rs index 7a1b47d2cb57..f5f0b0ddf12a 100644 --- a/crates/polars-plan/src/logical_plan/expr_ir.rs +++ b/crates/polars-plan/src/logical_plan/expr_ir.rs @@ -94,11 +94,11 @@ impl ExprIR { self.output_name.unwrap() } - pub(crate) fn output_name(&self) -> &str { + pub fn output_name(&self) -> &str { self.output_name_arc().as_ref() } - pub(crate) fn to_expr(&self, expr_arena: &Arena) -> Expr { + pub fn to_expr(&self, expr_arena: &Arena) -> Expr { let out = node_to_expr(self.node, expr_arena); match &self.output_name { diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py index b5962f7fba2c..06b4f6c91419 100644 --- a/py-polars/polars/testing/__init__.py +++ b/py-polars/polars/testing/__init__.py @@ -10,4 +10,5 @@ "assert_frame_not_equal", "assert_series_equal", "assert_series_not_equal", + "_constants", ] diff --git a/py-polars/polars/testing/_constants.py b/py-polars/polars/testing/_constants.py new file mode 100644 index 000000000000..8c11b6d0f176 --- /dev/null +++ b/py-polars/polars/testing/_constants.py @@ -0,0 +1,2 @@ +# On this limit Polars will start partitioning in debug builds +PARTITION_LIMIT = 15 diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 506422dc38dd..e7915115b79a 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -8,6 +8,7 @@ import polars as pl from polars.testing import assert_frame_equal +from polars.testing._constants import PARTITION_LIMIT if TYPE_CHECKING: from pathlib import Path @@ -480,3 +481,10 @@ def test_streaming_groupby_binary_15116() -> None: "str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"], "count": [3, 2, 2, 2, 1], } + + +def test_streaming_group_by_convert_15380() -> None: + assert ( + pl.DataFrame({"a": [1] * PARTITION_LIMIT}).group_by(b="a").len()["len"].item() + == PARTITION_LIMIT + ) From 7216f9cb9351ba2892416a77555ac38d4a17027a Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 29 Mar 2024 10:59:18 +0100 Subject: [PATCH 12/30] fix: Add FixedSizeBinary to arrow field conversion (#15389) --- crates/polars-core/src/datatypes/field.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index 8b0664b14168..009e88fee303 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -173,6 +173,7 @@ impl DataType { DataType::BinaryOffset } }, + ArrowDataType::FixedSizeBinary(_) => DataType::Binary, dt => panic!("Arrow datatype {dt:?} not supported by Polars. 
You probably need to activate that data-type feature."), } } From cf4df740cabb5662866d1e9f450b5368aafda2ae Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 29 Mar 2024 13:58:30 +0100 Subject: [PATCH 13/30] refactor: use recursive crate, add missing recursive tag (#15393) --- Cargo.lock | 21 + Cargo.toml | 1 + crates/polars-plan/Cargo.toml | 1 + .../src/logical_plan/aexpr/schema.rs | 3 + .../src/logical_plan/conversion.rs | 7 +- crates/polars-plan/src/logical_plan/mod.rs | 47 +- .../optimizer/predicate_pushdown/mod.rs | 862 +++++++++--------- .../optimizer/projection_pushdown/mod.rs | 650 +++++++------ .../optimizer/slice_pushdown_lp.rs | 459 +++++----- .../src/logical_plan/visitor/visitors.rs | 60 +- crates/polars-utils/src/lib.rs | 1 - crates/polars-utils/src/recursion.rs | 6 - 12 files changed, 1057 insertions(+), 1061 deletions(-) delete mode 100644 crates/polars-utils/src/recursion.rs diff --git a/Cargo.lock b/Cargo.lock index d0692b3eef05..93c1a6c43eaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2955,6 +2955,7 @@ dependencies = [ "polars-utils", "pyo3", "rayon", + "recursive", "regex", "serde", "smartstring", @@ -3333,6 +3334,26 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "recursive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11e2368fb7f6b4a62aab303ee5c24e79e0a0d120e52348457d41a4cbf21ed2e0" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad4a66026ca69c02b6807af930ac445871e9af1b1037e7a37064842851fc4233" +dependencies = [ + "quote", + "syn 2.0.55", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index e06e28faf4d9..7d75530dc2e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ rayon = "1.9" regex = "1.9" reqwest = { version = "0.11", default-features = false } ryu = "1.0.13" +recursive = "0.1" serde = "1.0.188" serde_json = "1" simd-json = { version = "0.13", features = ["known-key"] } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 27e84085e6d7..7a09c8c53173 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -33,6 +33,7 @@ once_cell = { workspace = true } percent-encoding = { workspace = true } pyo3 = { workspace = true, optional = true } rayon = { workspace = true } +recursive = { workspace = true } regex = { workspace = true, optional = true } serde = { workspace = true, features = ["derive", "rc"], optional = true } smartstring = { workspace = true } diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 6d8e692f96f3..b2829bf74334 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -1,3 +1,5 @@ +use recursive::recursive; + use super::*; fn float_type(field: &mut Field) { @@ -10,6 +12,7 @@ fn float_type(field: &mut Field) { impl AExpr { /// Get Field result of the expression. The schema is the input data. 
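+    //
+    // `#[recursive]` (from the newly added `recursive` crate) checks the
+    // remaining stack before each call and, when it runs low, continues on a
+    // fresh heap-allocated stack segment via `stacker`, so deeply nested
+    // expressions no longer overflow here.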
+ #[recursive] pub fn to_field( &self, schema: &Schema, diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index fa3d77290861..1cfb23b79bc4 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; use polars_utils::vec::ConvertVec; +use recursive::recursive; use crate::constants::get_len_name; use crate::prelude::*; @@ -62,7 +63,8 @@ fn to_aexprs(input: Vec, arena: &mut Arena, state: &mut ConversionS .collect() } -/// converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation +/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. +#[recursive] fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionState) -> Node { let v = match expr { Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(*expr, arena, state)), @@ -261,6 +263,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta /// converts LogicalPlan to ALogicalPlan /// it adds expressions & lps to the respective arenas as it traverses the plan /// finally it returns the top node of the logical plan +#[recursive] pub fn to_alp( lp: LogicalPlan, expr_arena: &mut Arena, @@ -474,6 +477,7 @@ pub fn to_alp( } /// converts a node from the AExpr arena to Expr +#[recursive] pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = expr_arena.get(node).clone(); @@ -705,6 +709,7 @@ fn expr_irs_to_exprs(expr_irs: Vec, expr_arena: &Arena) -> Vec( self, conversion_fn: &F, diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 8f119eca4eef..1b832b03ec57 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex}; use polars_core::prelude::*; -use polars_utils::recursion::with_dynamic_stack; +use recursive::recursive; use crate::logical_plan::LogicalPlan::DataFrameScan; use crate::prelude::*; @@ -250,30 +250,29 @@ impl Clone for LogicalPlan { // calls clone on every member of every enum variant. 
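    // Cloning recurses through the boxed inputs once per plan node; the
    // `#[recursive]` attribute added below grows the stack on demand so deep
    // plans don't overflow.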
#[rustfmt::skip] #[allow(clippy::clone_on_copy)] + #[recursive] fn clone(&self) -> Self { - with_dynamic_stack(|| { - match self { - #[cfg(feature = "python")] - Self::PythonScan { options } => Self::PythonScan { options: options.clone() }, - Self::Selection { input, predicate } => Self::Selection { input: input.clone(), predicate: predicate.clone() }, - Self::Cache { input, id, cache_hits } => Self::Cache { input: input.clone(), id: id.clone(), cache_hits: cache_hits.clone() }, - Self::Scan { paths, file_info, predicate, file_options, scan_type } => Self::Scan { paths: paths.clone(), file_info: file_info.clone(), predicate: predicate.clone(), file_options: file_options.clone(), scan_type: scan_type.clone() }, - Self::DataFrameScan { df, schema, output_schema, projection, selection } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), output_schema: output_schema.clone(), projection: projection.clone(), selection: selection.clone() }, - Self::Projection { expr, input, schema, options } => Self::Projection { expr: expr.clone(), input: input.clone(), schema: schema.clone(), options: options.clone() }, - Self::Aggregate { input, keys, aggs, schema, apply, maintain_order, options } => Self::Aggregate { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), schema: schema.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() }, - Self::Join { input_left, input_right, schema, left_on, right_on, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), schema: schema.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone() }, - Self::HStack { input, exprs, schema, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), schema: schema.clone(), options: options.clone() }, - Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() }, - Self::Sort { input, by_column, args } => Self::Sort { input: input.clone(), by_column: by_column.clone(), args: args.clone() }, - Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() }, - Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, - Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() }, - Self::HConcat { inputs, schema, options } => Self::HConcat { inputs: inputs.clone(), schema: schema.clone(), options: options.clone() }, - Self::Error { input, err } => Self::Error { input: input.clone(), err: err.clone() }, - Self::ExtContext { input, contexts, schema } => Self::ExtContext { input: input.clone(), contexts: contexts.clone(), schema: schema.clone() }, - Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() }, - } - }) + match self { + #[cfg(feature = "python")] + Self::PythonScan { options } => Self::PythonScan { options: options.clone() }, + Self::Selection { input, predicate } => Self::Selection { input: input.clone(), predicate: predicate.clone() }, + Self::Cache { input, id, cache_hits } => Self::Cache { input: input.clone(), id: id.clone(), cache_hits: cache_hits.clone() }, + Self::Scan { paths, file_info, predicate, file_options, scan_type } => Self::Scan { paths: paths.clone(), file_info: file_info.clone(), predicate: predicate.clone(), file_options: file_options.clone(), scan_type: scan_type.clone() }, + Self::DataFrameScan { df, schema, output_schema, projection, 
selection } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), output_schema: output_schema.clone(), projection: projection.clone(), selection: selection.clone() }, + Self::Projection { expr, input, schema, options } => Self::Projection { expr: expr.clone(), input: input.clone(), schema: schema.clone(), options: options.clone() }, + Self::Aggregate { input, keys, aggs, schema, apply, maintain_order, options } => Self::Aggregate { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), schema: schema.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() }, + Self::Join { input_left, input_right, schema, left_on, right_on, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), schema: schema.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone() }, + Self::HStack { input, exprs, schema, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), schema: schema.clone(), options: options.clone() }, + Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() }, + Self::Sort { input, by_column, args } => Self::Sort { input: input.clone(), by_column: by_column.clone(), args: args.clone() }, + Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() }, + Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, + Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() }, + Self::HConcat { inputs, schema, options } => Self::HConcat { inputs: inputs.clone(), schema: schema.clone(), options: options.clone() }, + Self::Error { input, err } => Self::Error { input: input.clone(), err: err.clone() }, + Self::ExtContext { input, contexts, schema } => Self::ExtContext { input: input.clone(), contexts: contexts.clone(), schema: schema.clone() }, + Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() }, + } } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index b533436c6650..679b58740b1c 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -6,7 +6,7 @@ mod utils; use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; -use polars_utils::recursion::with_dynamic_stack; +use recursive::recursive; use utils::*; use super::*; @@ -237,6 +237,7 @@ impl<'a> PredicatePushDown<'a> { /// The `Node`s are indexes in the `expr_arena` /// * `lp_arena` - The local memory arena for the logical plan. /// * `expr_arena` - The local memory arena for the expressions. + #[recursive] fn push_down( &self, lp: ALogicalPlan, @@ -244,489 +245,466 @@ impl<'a> PredicatePushDown<'a> { lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - with_dynamic_stack(|| { - use ALogicalPlan::*; - - match lp { - Selection { - ref predicate, - input, - } => { - // Use a tmp_key to avoid inadvertently combining predicates that otherwise would have - // been partially pushed: - // - // (1) .filter(pl.count().over("key") == 1) - // (2) .filter(pl.col("key") == 1) - // - // (2) can be pushed past (1) but they both have the same predicate - // key name in the hashtable. 
- let tmp_key = Arc::::from(&*temporary_unique_key(&acc_predicates)); - acc_predicates.insert(tmp_key.clone(), predicate.clone()); - - let local_predicates = - match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 { - PushdownEligibility::Full => vec![], - PushdownEligibility::Partial { to_local } => { - let mut out = Vec::with_capacity(to_local.len()); - for key in to_local { - out.push(acc_predicates.remove(&key).unwrap()); - } - out - }, - PushdownEligibility::NoPushdown => { - let out = acc_predicates.drain().map(|t| t.1).collect(); - acc_predicates.clear(); - out - }, - }; + use ALogicalPlan::*; + + match lp { + Selection { + ref predicate, + input, + } => { + // Use a tmp_key to avoid inadvertently combining predicates that otherwise would have + // been partially pushed: + // + // (1) .filter(pl.count().over("key") == 1) + // (2) .filter(pl.col("key") == 1) + // + // (2) can be pushed past (1) but they both have the same predicate + // key name in the hashtable. + let tmp_key = Arc::::from(&*temporary_unique_key(&acc_predicates)); + acc_predicates.insert(tmp_key.clone(), predicate.clone()); + + let local_predicates = + match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 { + PushdownEligibility::Full => vec![], + PushdownEligibility::Partial { to_local } => { + let mut out = Vec::with_capacity(to_local.len()); + for key in to_local { + out.push(acc_predicates.remove(&key).unwrap()); + } + out + }, + PushdownEligibility::NoPushdown => { + let out = acc_predicates.drain().map(|t| t.1).collect(); + acc_predicates.clear(); + out + }, + }; - if let Some(predicate) = acc_predicates.remove(&tmp_key) { - insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena); - } + if let Some(predicate) = acc_predicates.remove(&tmp_key) { + insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena); + } - let alp = lp_arena.take(input); - let new_input = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?; + let alp = lp_arena.take(input); + let new_input = self.push_down(alp, acc_predicates, lp_arena, expr_arena)?; - // TODO! - // If a predicates result would be influenced by earlier applied - // predicates, we simply don't pushdown this one passed this node - // However, we can do better and let it pass but store the order of the predicates - // so that we can apply them in correct order at the deepest level - Ok(self.optional_apply_predicate( + // TODO! 
+ // If a predicates result would be influenced by earlier applied + // predicates, we simply don't pushdown this one passed this node + // However, we can do better and let it pass but store the order of the predicates + // so that we can apply them in correct order at the deepest level + Ok( + self.optional_apply_predicate( new_input, local_predicates, lp_arena, expr_arena, - )) - }, - DataFrameScan { + ), + ) + }, + DataFrameScan { + df, + schema, + output_schema, + projection, + selection, + } => { + let selection = predicate_at_scan(acc_predicates, selection, expr_arena); + let lp = DataFrameScan { df, schema, output_schema, projection, selection, - } => { - let selection = predicate_at_scan(acc_predicates, selection, expr_arena); - let lp = DataFrameScan { - df, - schema, - output_schema, - projection, - selection, - }; - Ok(lp) - }, - Scan { - mut paths, - mut file_info, - ref predicate, - mut scan_type, - file_options: options, - output_schema, - } => { - for e in acc_predicates.values() { - debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena); - } + }; + Ok(lp) + }, + Scan { + mut paths, + mut file_info, + ref predicate, + mut scan_type, + file_options: options, + output_schema, + } => { + for e in acc_predicates.values() { + debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena); + } - let local_predicates = match &scan_type { - #[cfg(feature = "parquet")] - FileScan::Parquet { .. } => vec![], - #[cfg(feature = "ipc")] - FileScan::Ipc { .. } => vec![], - _ => { - // Disallow row index pushdown of other scans as they may - // not update the row index properly before applying the - // predicate (e.g. FileScan::Csv doesn't). - if let Some(ref row_index) = options.row_index { - let row_index_predicates = transfer_to_local_by_name( - expr_arena, - &mut acc_predicates, - |name| name.as_ref() == row_index.name, - ); - row_index_predicates - } else { - vec![] - } - }, - }; - let predicate = - predicate_at_scan(acc_predicates, predicate.clone(), expr_arena); - - if let (true, Some(predicate)) = (file_info.hive_parts.is_some(), &predicate) { - if let Some(io_expr) = - self.hive_partition_eval.unwrap()(predicate, expr_arena) - { - if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { - let mut new_paths = Vec::with_capacity(paths.len()); - - for path in paths.as_ref().iter() { - file_info.update_hive_partitions(path)?; - let hive_part_stats = file_info.hive_parts.as_deref().ok_or_else(|| polars_err!(ComputeError: "cannot combine hive partitioned directories with non-hive partitioned ones"))?; - - if stats_evaluator - .should_read(hive_part_stats.get_statistics())? - { - new_paths.push(path.clone()); - } + let local_predicates = match &scan_type { + #[cfg(feature = "parquet")] + FileScan::Parquet { .. } => vec![], + #[cfg(feature = "ipc")] + FileScan::Ipc { .. } => vec![], + _ => { + // Disallow row index pushdown of other scans as they may + // not update the row index properly before applying the + // predicate (e.g. FileScan::Csv doesn't). 
+ if let Some(ref row_index) = options.row_index { + let row_index_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + |name| name.as_ref() == row_index.name, + ); + row_index_predicates + } else { + vec![] + } + }, + }; + let predicate = predicate_at_scan(acc_predicates, predicate.clone(), expr_arena); + + if let (true, Some(predicate)) = (file_info.hive_parts.is_some(), &predicate) { + if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) + { + if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { + let mut new_paths = Vec::with_capacity(paths.len()); + + for path in paths.as_ref().iter() { + file_info.update_hive_partitions(path)?; + let hive_part_stats = file_info.hive_parts.as_deref().ok_or_else(|| polars_err!(ComputeError: "cannot combine hive partitioned directories with non-hive partitioned ones"))?; + + if stats_evaluator.should_read(hive_part_stats.get_statistics())? { + new_paths.push(path.clone()); } + } - if paths.len() != new_paths.len() { - if self.verbose { - eprintln!( - "hive partitioning: skipped {} files, first file : {}", - paths.len() - new_paths.len(), - paths[0].display() - ) - } - scan_type.remove_metadata(); - } - if paths.is_empty() { - let schema = - output_schema.as_ref().unwrap_or(&file_info.schema); - let df = DataFrame::from(schema.as_ref()); - - return Ok(DataFrameScan { - df: Arc::new(df), - schema: schema.clone(), - output_schema: None, - projection: None, - selection: None, - }); - } else { - paths = Arc::from(new_paths) + if paths.len() != new_paths.len() { + if self.verbose { + eprintln!( + "hive partitioning: skipped {} files, first file : {}", + paths.len() - new_paths.len(), + paths[0].display() + ) } + scan_type.remove_metadata(); + } + if paths.is_empty() { + let schema = output_schema.as_ref().unwrap_or(&file_info.schema); + let df = DataFrame::from(schema.as_ref()); + + return Ok(DataFrameScan { + df: Arc::new(df), + schema: schema.clone(), + output_schema: None, + projection: None, + selection: None, + }); + } else { + paths = Arc::from(new_paths) } } } + } - let mut do_optimization = match &scan_type { - #[cfg(feature = "csv")] - FileScan::Csv { .. } => options.n_rows.is_none(), - FileScan::Anonymous { function, .. } => { - function.allows_predicate_pushdown() - }, - #[allow(unreachable_patterns)] - _ => true, + let mut do_optimization = match &scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { .. } => options.n_rows.is_none(), + FileScan::Anonymous { function, .. 
} => function.allows_predicate_pushdown(), + #[allow(unreachable_patterns)] + _ => true, + }; + do_optimization &= predicate.is_some(); + + let lp = if do_optimization { + Scan { + paths, + file_info, + predicate, + file_options: options, + output_schema, + scan_type, + } + } else { + let lp = Scan { + paths, + file_info, + predicate: None, + file_options: options, + output_schema, + scan_type, }; - do_optimization &= predicate.is_some(); - - let lp = if do_optimization { - Scan { - paths, - file_info, - predicate, - file_options: options, - output_schema, - scan_type, - } + if let Some(predicate) = predicate { + let input = lp_arena.add(lp); + Selection { input, predicate } } else { - let lp = Scan { - paths, - file_info, - predicate: None, - file_options: options, - output_schema, - scan_type, - }; - if let Some(predicate) = predicate { - let input = lp_arena.add(lp); - Selection { input, predicate } - } else { - lp - } - }; + lp + } + }; + + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + }, + Distinct { input, options } => { + if let Some(ref subset) = options.subset { + // Predicates on the subset can pass. + let subset = subset.clone(); + let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); + for name in subset.iter() { + names_set.insert(name.as_str()); + } + let condition = |name: Arc| !names_set.contains(name.as_ref()); + let local_predicates = + transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); + + self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?; + let lp = Distinct { input, options }; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - }, - Distinct { input, options } => { - if let Some(ref subset) = options.subset { - // Predicates on the subset can pass. - let subset = subset.clone(); - let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); - for name in subset.iter() { - names_set.insert(name.as_str()); - } + } else { + let lp = Distinct { input, options }; + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + } + }, + Join { + input_left, + input_right, + left_on, + right_on, + schema, + options, + } => process_join( + self, + lp_arena, + expr_arena, + input_left, + input_right, + left_on, + right_on, + schema, + options, + acc_predicates, + ), + MapFunction { ref function, .. } => { + if function.allow_predicate_pd() { + match function { + FunctionNode::Rename { existing, new, .. } => { + let local_predicates = + process_rename(&mut acc_predicates, expr_arena, existing, new)?; + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + FunctionNode::Explode { columns, .. 
} => { + let condition = + |name: Arc| columns.iter().any(|s| s.as_ref() == &*name); - let condition = |name: Arc| !names_set.contains(name.as_ref()); - let local_predicates = - transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); + // first columns that refer to the exploded columns should be done here + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); - self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?; - let lp = Distinct { input, options }; - Ok(self.optional_apply_predicate( + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + FunctionNode::Melt { args, .. } => { + let variable_name = args.variable_name.as_deref().unwrap_or("variable"); + let value_name = args.value_name.as_deref().unwrap_or("value"); + + // predicates that will be done at this level + let condition = |name: Arc| { + let name = &*name; + name == variable_name + || name == value_name + || args.value_vars.iter().any(|s| s.as_str() == name) + }; + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + _ => self.pushdown_and_continue( lp, - local_predicates, + acc_predicates, lp_arena, expr_arena, - )) + false, + ), + } + } else { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + } + }, + Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + } => process_group_by( + self, + lp_arena, + expr_arena, + input, + keys, + aggs, + schema, + maintain_order, + apply, + options, + acc_predicates, + ), + lp @ Union { .. } => { + let mut local_predicates = vec![]; + + // a count is influenced by a Union/Vstack + acc_predicates.retain(|_, predicate| { + if has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len)) { + local_predicates.push(predicate.clone()); + false } else { - let lp = Distinct { input, options }; - self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + true } - }, - Join { - input_left, - input_right, - left_on, - right_on, - schema, - options, - } => process_join( - self, - lp_arena, - expr_arena, - input_left, - input_right, - left_on, - right_on, - schema, - options, - acc_predicates, - ), - MapFunction { ref function, .. } => { - if function.allow_predicate_pd() { - match function { - FunctionNode::Rename { existing, new, .. } => { - let local_predicates = - process_rename(&mut acc_predicates, expr_arena, existing, new)?; - let lp = self.pushdown_and_continue( - lp, - acc_predicates, - lp_arena, - expr_arena, - false, - )?; - Ok(self.optional_apply_predicate( - lp, - local_predicates, - lp_arena, - expr_arena, - )) - }, - FunctionNode::Explode { columns, .. } => { - let condition = - |name: Arc| columns.iter().any(|s| s.as_ref() == &*name); - - // first columns that refer to the exploded columns should be done here - let local_predicates = transfer_to_local_by_name( - expr_arena, - &mut acc_predicates, - condition, - ); + }); + let lp = + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + }, + lp @ Sort { .. 
} => { + let mut local_predicates = vec![]; + acc_predicates.retain(|_, predicate| { + if predicate_is_sort_boundary(predicate.node(), expr_arena) { + local_predicates.push(predicate.clone()); + false + } else { + true + } + }); + let lp = + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + }, + // Pushed down passed these nodes + lp @ Sink { .. } => { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + }, + lp @ HStack { .. } + | lp @ Projection { .. } + | lp @ SimpleProjection { .. } + | lp @ ExtContext { .. } => { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) + }, + // NOT Pushed down passed these nodes + // predicates influence slice sizes + lp @ Slice { .. } => { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + }, + lp @ HConcat { .. } => { + self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) + }, + // Caches will run predicate push-down in the `cache_states` run. + Cache { .. } => { + if self.block_at_cache { + self.no_pushdown(lp, acc_predicates, lp_arena, expr_arena) + } else { + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) + } + }, + #[cfg(feature = "python")] + PythonScan { + mut options, + predicate, + } => { + if options.pyarrow { + let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); + + if let Some(predicate) = predicate.clone() { + // simplify expressions before we translate them to pyarrow + let lp = PythonScan { + options: options.clone(), + predicate: Some(predicate), + }; + let lp_top = lp_arena.add(lp); + let stack_opt = StackOptimizer {}; + let lp_top = stack_opt + .optimize_loop( + &mut [Box::new(SimplifyExprRule {})], + expr_arena, + lp_arena, + lp_top, + ) + .unwrap(); + let PythonScan { + options: _, + predicate: Some(predicate), + } = lp_arena.take(lp_top) + else { + unreachable!() + }; - let lp = self.pushdown_and_continue( - lp, - acc_predicates, - lp_arena, - expr_arena, - false, - )?; - Ok(self.optional_apply_predicate( - lp, - local_predicates, - lp_arena, - expr_arena, - )) - }, - FunctionNode::Melt { args, .. 
} => { - let variable_name = - args.variable_name.as_deref().unwrap_or("variable"); - let value_name = args.value_name.as_deref().unwrap_or("value"); - - // predicates that will be done at this level - let condition = |name: Arc| { - let name = &*name; - name == variable_name - || name == value_name - || args.value_vars.iter().any(|s| s.as_str() == name) + match super::super::pyarrow::predicate_to_pa( + predicate.node(), + expr_arena, + Default::default(), + ) { + // we we able to create a pyarrow string, mutate the options + Some(eval_str) => options.predicate = Some(eval_str), + // we were not able to translate the predicate + // apply here + None => { + let lp = PythonScan { + options, + predicate: None, }; - let local_predicates = transfer_to_local_by_name( - expr_arena, - &mut acc_predicates, - condition, - ); - - let lp = self.pushdown_and_continue( - lp, - acc_predicates, - lp_arena, - expr_arena, - false, - )?; - Ok(self.optional_apply_predicate( + return Ok(self.optional_apply_predicate( lp, - local_predicates, + vec![predicate], lp_arena, expr_arena, - )) + )); }, - _ => self.pushdown_and_continue( - lp, - acc_predicates, - lp_arena, - expr_arena, - false, - ), } - } else { - self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } - }, - Aggregate { - input, - keys, - aggs, - schema, - apply, - maintain_order, - options, - } => process_group_by( - self, - lp_arena, - expr_arena, - input, - keys, - aggs, - schema, - maintain_order, - apply, - options, - acc_predicates, - ), - lp @ Union { .. } => { - let mut local_predicates = vec![]; - - // a count is influenced by a Union/Vstack - acc_predicates.retain(|_, predicate| { - if has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len)) { - local_predicates.push(predicate.clone()); - false - } else { - true - } - }); - let lp = self.pushdown_and_continue( - lp, - acc_predicates, - lp_arena, - expr_arena, - false, - )?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - }, - lp @ Sort { .. } => { - let mut local_predicates = vec![]; - acc_predicates.retain(|_, predicate| { - if predicate_is_sort_boundary(predicate.node(), expr_arena) { - local_predicates.push(predicate.clone()); - false - } else { - true - } - }); - let lp = self.pushdown_and_continue( - lp, + Ok(PythonScan { options, predicate }) + } else { + self.no_pushdown_restart_opt( + PythonScan { options, predicate }, acc_predicates, lp_arena, expr_arena, - false, - )?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - }, - // Pushed down passed these nodes - lp @ Sink { .. } => { - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) - }, - lp @ HStack { .. } - | lp @ Projection { .. } - | lp @ SimpleProjection { .. } - | lp @ ExtContext { .. } => { - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) - }, - // NOT Pushed down passed these nodes - // predicates influence slice sizes - lp @ Slice { .. } => { - self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) - }, - lp @ HConcat { .. } => { - self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) - }, - // Caches will run predicate push-down in the `cache_states` run. - Cache { .. 
} => { - if self.block_at_cache { - self.no_pushdown(lp, acc_predicates, lp_arena, expr_arena) - } else { - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) - } - }, - #[cfg(feature = "python")] - PythonScan { - mut options, - predicate, - } => { - if options.pyarrow { - let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); - - if let Some(predicate) = predicate.clone() { - // simplify expressions before we translate them to pyarrow - let lp = PythonScan { - options: options.clone(), - predicate: Some(predicate), - }; - let lp_top = lp_arena.add(lp); - let stack_opt = StackOptimizer {}; - let lp_top = stack_opt - .optimize_loop( - &mut [Box::new(SimplifyExprRule {})], - expr_arena, - lp_arena, - lp_top, - ) - .unwrap(); - let PythonScan { - options: _, - predicate: Some(predicate), - } = lp_arena.take(lp_top) - else { - unreachable!() - }; - - match super::super::pyarrow::predicate_to_pa( - predicate.node(), - expr_arena, - Default::default(), - ) { - // we we able to create a pyarrow string, mutate the options - Some(eval_str) => options.predicate = Some(eval_str), - // we were not able to translate the predicate - // apply here - None => { - let lp = PythonScan { - options, - predicate: None, - }; - return Ok(self.optional_apply_predicate( - lp, - vec![predicate], - lp_arena, - expr_arena, - )); - }, - } - } - Ok(PythonScan { options, predicate }) - } else { - self.no_pushdown_restart_opt( - PythonScan { options, predicate }, - acc_predicates, - lp_arena, - expr_arena, - ) - } - }, - Invalid => unreachable!(), - } - }) + ) + } + }, + Invalid => unreachable!(), + } } pub(crate) fn optimize( diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 6dea5dcc2763..95a0b09de494 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -12,7 +12,7 @@ mod semi_anti_join; use polars_core::datatypes::PlHashSet; use polars_core::prelude::*; use polars_io::RowIndex; -use polars_utils::recursion::with_dynamic_stack; +use recursive::recursive; #[cfg(feature = "semi_anti_join")] use semi_anti_join::process_semi_anti_join; @@ -316,7 +316,7 @@ impl ProjectionPushDown { /// * `projections_seen` - Count the number of projection operations during tree traversal. /// * `lp_arena` - The local memory arena for the logical plan. /// * `expr_arena` - The local memory arena for the expressions. - /// + #[recursive] fn push_down( &mut self, logical_plan: ALogicalPlan, @@ -326,397 +326,395 @@ impl ProjectionPushDown { lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - with_dynamic_stack(|| { - use ALogicalPlan::*; - - match logical_plan { - Projection { expr, input, .. } => process_projection( + use ALogicalPlan::*; + + match logical_plan { + Projection { expr, input, .. } => process_projection( + self, + input, + expr.exprs(), + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + SimpleProjection { columns, input, .. } => { + let exprs = names_to_expr_irs(columns.iter_names(), expr_arena); + process_projection( self, input, - expr.exprs(), + exprs, acc_projections, projected_names, projections_seen, lp_arena, expr_arena, - ), - SimpleProjection { columns, input, .. 
} => { - let exprs = names_to_expr_irs(columns.iter_names(), expr_arena); - process_projection( - self, - input, - exprs, - acc_projections, - projected_names, - projections_seen, - lp_arena, + ) + }, + DataFrameScan { + df, + schema, + mut output_schema, + selection, + .. + } => { + let mut projection = None; + if !acc_projections.is_empty() { + output_schema = Some(Arc::new(update_scan_schema( + &acc_projections, expr_arena, - ) - }, - DataFrameScan { + &schema, + false, + )?)); + projection = get_scan_columns(&mut acc_projections, expr_arena, None); + } + let lp = DataFrameScan { df, schema, - mut output_schema, + output_schema, + projection, selection, - .. - } => { - let mut projection = None; - if !acc_projections.is_empty() { - output_schema = Some(Arc::new(update_scan_schema( - &acc_projections, - expr_arena, - &schema, - false, - )?)); - projection = get_scan_columns(&mut acc_projections, expr_arena, None); - } - let lp = DataFrameScan { - df, - schema, - output_schema, - projection, - selection, - }; - Ok(lp) - }, - #[cfg(feature = "python")] - PythonScan { - mut options, - predicate, - } => { - options.with_columns = get_scan_columns(&mut acc_projections, expr_arena, None); + }; + Ok(lp) + }, + #[cfg(feature = "python")] + PythonScan { + mut options, + predicate, + } => { + options.with_columns = get_scan_columns(&mut acc_projections, expr_arena, None); + + options.output_schema = if options.with_columns.is_none() { + None + } else { + Some(Arc::new(update_scan_schema( + &acc_projections, + expr_arena, + &options.schema, + true, + )?)) + }; + Ok(PythonScan { options, predicate }) + }, + Scan { + paths, + file_info, + scan_type, + predicate, + mut file_options, + mut output_schema, + } => { + let mut do_optimization = true; + #[allow(irrefutable_let_patterns)] + if let FileScan::Anonymous { ref function, .. } = scan_type { + do_optimization = function.allows_projection_pushdown(); + } + + if do_optimization { + file_options.with_columns = get_scan_columns( + &mut acc_projections, + expr_arena, + file_options.row_index.as_ref(), + ); - options.output_schema = if options.with_columns.is_none() { + output_schema = if file_options.with_columns.is_none() { None } else { - Some(Arc::new(update_scan_schema( + let mut schema = update_scan_schema( &acc_projections, expr_arena, - &options.schema, - true, - )?)) + &file_info.schema, + scan_type.sort_projection(&file_options), + )?; + // Hive partitions are created AFTER the projection, so the output + // schema is incorrect. Here we ensure the columns that are projected and hive + // parts are added at the proper place in the schema, which is at the end. + if let Some(parts) = file_info.hive_parts.as_deref() { + let partition_schema = parts.schema(); + for (name, _) in partition_schema.iter() { + if let Some(dt) = schema.shift_remove(name) { + schema.with_column(name.clone(), dt); + } + } + } + Some(Arc::new(schema)) }; - Ok(PythonScan { options, predicate }) - }, - Scan { + } + + let lp = Scan { paths, file_info, + output_schema, scan_type, predicate, - mut file_options, - mut output_schema, - } => { - let mut do_optimization = true; - #[allow(irrefutable_let_patterns)] - if let FileScan::Anonymous { ref function, .. 
} = scan_type { - do_optimization = function.allows_projection_pushdown(); - } - - if do_optimization { - file_options.with_columns = get_scan_columns( + file_options, + }; + Ok(lp) + }, + Sort { + input, + by_column, + args, + } => { + if !acc_projections.is_empty() { + // Make sure that the column(s) used for the sort is projected + by_column.iter().for_each(|node| { + add_expr_to_accumulated( + node.node(), &mut acc_projections, + &mut projected_names, expr_arena, - file_options.row_index.as_ref(), ); + }); + } - output_schema = if file_options.with_columns.is_none() { - None - } else { - let mut schema = update_scan_schema( - &acc_projections, - expr_arena, - &file_info.schema, - scan_type.sort_projection(&file_options), - )?; - // Hive partitions are created AFTER the projection, so the output - // schema is incorrect. Here we ensure the columns that are projected and hive - // parts are added at the proper place in the schema, which is at the end. - if let Some(parts) = file_info.hive_parts.as_deref() { - let partition_schema = parts.schema(); - for (name, _) in partition_schema.iter() { - if let Some(dt) = schema.shift_remove(name) { - schema.with_column(name.clone(), dt); - } - } - } - Some(Arc::new(schema)) - }; - } - - let lp = Scan { - paths, - file_info, - output_schema, - scan_type, - predicate, - file_options, - }; - Ok(lp) - }, - Sort { + self.pushdown_and_assign( + input, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + )?; + Ok(Sort { input, by_column, args, - } => { - if !acc_projections.is_empty() { - // Make sure that the column(s) used for the sort is projected - by_column.iter().for_each(|node| { - add_expr_to_accumulated( - node.node(), + }) + }, + Distinct { input, options } => { + // make sure that the set of unique columns is projected + if !acc_projections.is_empty() { + if let Some(subset) = options.subset.as_ref() { + subset.iter().for_each(|name| { + add_str_to_accumulated( + name, &mut acc_projections, &mut projected_names, expr_arena, - ); - }); - } - - self.pushdown_and_assign( - input, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - )?; - Ok(Sort { - input, - by_column, - args, - }) - }, - Distinct { input, options } => { - // make sure that the set of unique columns is projected - if !acc_projections.is_empty() { - if let Some(subset) = options.subset.as_ref() { - subset.iter().for_each(|name| { - add_str_to_accumulated( - name, - &mut acc_projections, - &mut projected_names, - expr_arena, - ) - }) - } else { - // distinct needs all columns - let input_schema = lp_arena.get(input).schema(lp_arena); - for name in input_schema.iter_names() { - add_str_to_accumulated( - name.as_str(), - &mut acc_projections, - &mut projected_names, - expr_arena, - ) - } + ) + }) + } else { + // distinct needs all columns + let input_schema = lp_arena.get(input).schema(lp_arena); + for name in input_schema.iter_names() { + add_str_to_accumulated( + name.as_str(), + &mut acc_projections, + &mut projected_names, + expr_arena, + ) } } + } - self.pushdown_and_assign( - input, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - )?; - Ok(Distinct { input, options }) - }, - Selection { predicate, input } => { - if !acc_projections.is_empty() { - // make sure that the filter column is projected - add_expr_to_accumulated( - predicate.node(), - &mut acc_projections, - &mut projected_names, - expr_arena, - ); - }; - self.pushdown_and_assign( - input, - acc_projections, - 
projected_names, - projections_seen, - lp_arena, - expr_arena, - )?; - Ok(Selection { predicate, input }) - }, - Aggregate { + self.pushdown_and_assign( input, - keys, - aggs, - apply, - schema, - maintain_order, - options, - } => process_group_by( - self, - input, - keys, - aggs, - apply, - schema, - maintain_order, - options, acc_projections, projected_names, projections_seen, lp_arena, expr_arena, - ), - Join { - input_left, - input_right, - left_on, - right_on, - options, - schema, - } => match options.args.how { - #[cfg(feature = "semi_anti_join")] - JoinType::Semi | JoinType::Anti => process_semi_anti_join( - self, - input_left, - input_right, - left_on, - right_on, - options, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - ), - _ => process_join( - self, - input_left, - input_right, - left_on, - right_on, - options, - acc_projections, - projected_names, - projections_seen, - lp_arena, + )?; + Ok(Distinct { input, options }) + }, + Selection { predicate, input } => { + if !acc_projections.is_empty() { + // make sure that the filter column is projected + add_expr_to_accumulated( + predicate.node(), + &mut acc_projections, + &mut projected_names, expr_arena, - &schema, - ), - }, - HStack { - input, - exprs, - options, - .. - } => process_hstack( - self, + ); + }; + self.pushdown_and_assign( input, - exprs.exprs(), - options, acc_projections, projected_names, projections_seen, lp_arena, expr_arena, - ), - ExtContext { - input, contexts, .. - } => { - // local projections are ignored. These are just root nodes - // complex expression will still be done later - let _local_projections = self.pushdown_and_assign_check_schema( - input, - acc_projections, - projections_seen, - lp_arena, - expr_arena, - false, - )?; - - let mut new_schema = lp_arena - .get(input) - .schema(lp_arena) - .as_ref() - .as_ref() - .clone(); - - for node in &contexts { - let other_schema = lp_arena.get(*node).schema(lp_arena); - for fld in other_schema.iter_fields() { - if new_schema.get(fld.name()).is_none() { - new_schema.with_column(fld.name, fld.dtype); - } - } - } - - Ok(ExtContext { - input, - contexts, - schema: Arc::new(new_schema), - }) - }, - MapFunction { - input, - ref function, - } => functions::process_functions( + )?; + Ok(Selection { predicate, input }) + }, + Aggregate { + input, + keys, + aggs, + apply, + schema, + maintain_order, + options, + } => process_group_by( + self, + input, + keys, + aggs, + apply, + schema, + maintain_order, + options, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + Join { + input_left, + input_right, + left_on, + right_on, + options, + schema, + } => match options.args.how { + #[cfg(feature = "semi_anti_join")] + JoinType::Semi | JoinType::Anti => process_semi_anti_join( self, - input, - function, + input_left, + input_right, + left_on, + right_on, + options, acc_projections, projected_names, projections_seen, lp_arena, expr_arena, ), - HConcat { - inputs, - schema, - options, - } => process_hconcat( + _ => process_join( self, - inputs, - schema, + input_left, + input_right, + left_on, + right_on, options, acc_projections, - projections_seen, - lp_arena, - expr_arena, - ), - lp @ Union { .. } => process_generic( - self, - lp, - acc_projections, projected_names, projections_seen, lp_arena, expr_arena, + &schema, ), - // These nodes only have inputs and exprs, so we can use same logic. - lp @ Slice { .. } | lp @ Sink { .. 
} => process_generic( - self, - lp, + }, + HStack { + input, + exprs, + options, + .. + } => process_hstack( + self, + input, + exprs.exprs(), + options, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + ExtContext { + input, contexts, .. + } => { + // local projections are ignored. These are just root nodes + // complex expression will still be done later + let _local_projections = self.pushdown_and_assign_check_schema( + input, acc_projections, - projected_names, projections_seen, lp_arena, expr_arena, - ), - Cache { .. } => { - // projections above this cache will be accumulated and pushed down - // later - // the redundant projection will be cleaned in the fast projection optimization - // phase. - if acc_projections.is_empty() { - Ok(logical_plan) - } else { - Ok( - ALogicalPlanBuilder::from_lp(logical_plan, expr_arena, lp_arena) - .project_simple_nodes(acc_projections) - .unwrap() - .build(), - ) + false, + )?; + + let mut new_schema = lp_arena + .get(input) + .schema(lp_arena) + .as_ref() + .as_ref() + .clone(); + + for node in &contexts { + let other_schema = lp_arena.get(*node).schema(lp_arena); + for fld in other_schema.iter_fields() { + if new_schema.get(fld.name()).is_none() { + new_schema.with_column(fld.name, fld.dtype); + } } - }, - Invalid => unreachable!(), - } - }) + } + + Ok(ExtContext { + input, + contexts, + schema: Arc::new(new_schema), + }) + }, + MapFunction { + input, + ref function, + } => functions::process_functions( + self, + input, + function, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + HConcat { + inputs, + schema, + options, + } => process_hconcat( + self, + inputs, + schema, + options, + acc_projections, + projections_seen, + lp_arena, + expr_arena, + ), + lp @ Union { .. } => process_generic( + self, + lp, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + // These nodes only have inputs and exprs, so we can use same logic. + lp @ Slice { .. } | lp @ Sink { .. } => process_generic( + self, + lp, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), + Cache { .. } => { + // projections above this cache will be accumulated and pushed down + // later + // the redundant projection will be cleaned in the fast projection optimization + // phase. 
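// Illustrative sketch (toy names, not Polars' actual types): how projections
// accumulated above a Cache are materialized on top of it rather than pushed
// through, since other consumers of the same cache may need more columns; a
// later pass cleans up any redundant projection.
#[derive(Debug)]
enum ToyPlan {
    Scan { columns: Vec<String> },
    Cache { input: Box<ToyPlan> },
    Project { input: Box<ToyPlan>, columns: Vec<String> },
}

fn push_down(plan: ToyPlan, acc: Vec<String>) -> ToyPlan {
    match plan {
        // A scan absorbs the accumulated projection directly.
        ToyPlan::Scan { .. } if !acc.is_empty() => ToyPlan::Scan { columns: acc },
        // Stop at the cache: re-create the projection on top instead of
        // rewriting state that other plan branches may share.
        cache @ ToyPlan::Cache { .. } if !acc.is_empty() => ToyPlan::Project {
            input: Box::new(cache),
            columns: acc,
        },
        // A projection replaces the accumulated set and recurses into its input.
        ToyPlan::Project { input, columns } => push_down(*input, columns),
        other => other,
    }
}

fn main() {
    let plan = ToyPlan::Project {
        input: Box::new(ToyPlan::Cache {
            input: Box::new(ToyPlan::Scan {
                columns: vec!["a".into(), "b".into(), "c".into()],
            }),
        }),
        columns: vec!["a".into()],
    };
    // Prints Project { input: Cache { .. }, columns: ["a"] }:
    // the projection stays above the cache rather than reaching the scan.
    println!("{:?}", push_down(plan, vec![]));
}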
+ if acc_projections.is_empty() { + Ok(logical_plan) + } else { + Ok( + ALogicalPlanBuilder::from_lp(logical_plan, expr_arena, lp_arena) + .project_simple_nodes(acc_projections) + .unwrap() + .build(), + ) + } + }, + Invalid => unreachable!(), + } } pub fn optimize( diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 40bc01aeb53a..37326f4db1ff 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -1,5 +1,5 @@ use polars_core::prelude::*; -use polars_utils::recursion::with_dynamic_stack; +use recursive::recursive; use crate::logical_plan::projection_expr::ProjectionExprs; use crate::prelude::*; @@ -139,6 +139,7 @@ impl SlicePushDown { Ok(lp.with_exprs_and_input(exprs, new_inputs)) } + #[recursive] fn pushdown( &self, lp: ALogicalPlan, @@ -146,267 +147,265 @@ impl SlicePushDown { lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - with_dynamic_stack(|| { - use ALogicalPlan::*; + use ALogicalPlan::*; - match (lp, state) { - #[cfg(feature = "python")] - (PythonScan { - mut options, - predicate, - }, - // TODO! we currently skip slice pushdown if there is a predicate. - // we can modify the readers to only limit after predicates have been applied - Some(state)) if state.offset == 0 && predicate.is_none() => { - options.n_rows = Some(state.len as usize); - let lp = PythonScan { - options, - predicate - }; - Ok(lp) - } - #[cfg(feature = "csv")] - (Scan { + match (lp, state) { + #[cfg(feature = "python")] + (PythonScan { + mut options, + predicate, + }, + // TODO! we currently skip slice pushdown if there is a predicate. + // we can modify the readers to only limit after predicates have been applied + Some(state)) if state.offset == 0 && predicate.is_none() => { + options.n_rows = Some(state.len as usize); + let lp = PythonScan { + options, + predicate + }; + Ok(lp) + } + #[cfg(feature = "csv")] + (Scan { + paths, + file_info, + output_schema, + file_options: mut options, + predicate, + scan_type: FileScan::Csv {options: mut csv_options} + }, Some(state)) if predicate.is_none() && state.offset >= 0 => { + options.n_rows = Some(state.len as usize); + csv_options.skip_rows += state.offset as usize; + + let lp = Scan { paths, file_info, output_schema, - file_options: mut options, + scan_type: FileScan::Csv {options: csv_options}, + file_options: options, predicate, - scan_type: FileScan::Csv {options: mut csv_options} - }, Some(state)) if predicate.is_none() && state.offset >= 0 => { - options.n_rows = Some(state.len as usize); - csv_options.skip_rows += state.offset as usize; - - let lp = Scan { - paths, - file_info, - output_schema, - scan_type: FileScan::Csv {options: csv_options}, - file_options: options, - predicate, - }; - Ok(lp) - }, - // TODO! we currently skip slice pushdown if there is a predicate. - (Scan { + }; + Ok(lp) + }, + // TODO! we currently skip slice pushdown if there is a predicate. 
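// Illustrative sketch (hypothetical field names, not the real option structs):
// what the CSV arm above does with a positive slice offset — it folds the
// slice into the reader itself instead of materializing everything first.
struct ToyCsvOptions {
    skip_rows: usize,
    n_rows: Option<usize>,
}

// Fold `slice(offset, len)` into the reader: skip `offset` rows, then read `len`.
fn push_slice_into_csv(opts: &mut ToyCsvOptions, offset: i64, len: usize) {
    assert!(offset >= 0, "negative offsets stay as a Slice node above the scan");
    opts.skip_rows += offset as usize;
    opts.n_rows = Some(len);
}

fn main() {
    let mut opts = ToyCsvOptions { skip_rows: 0, n_rows: None };
    push_slice_into_csv(&mut opts, 100, 25);
    // The reader now produces rows 100..125 directly.
    assert_eq!((opts.skip_rows, opts.n_rows), (100, Some(25)));
}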
+ (Scan { + paths, + file_info, + output_schema, + file_options: mut options, + predicate, + scan_type + }, Some(state)) if state.offset == 0 && predicate.is_none() => { + options.n_rows = Some(state.len as usize); + let lp = Scan { paths, file_info, output_schema, - file_options: mut options, predicate, + file_options: options, scan_type - }, Some(state)) if state.offset == 0 && predicate.is_none() => { - options.n_rows = Some(state.len as usize); - let lp = Scan { - paths, - file_info, - output_schema, - predicate, - file_options: options, - scan_type - }; + }; - Ok(lp) - } - (Union {mut inputs, mut options }, Some(state)) => { - options.slice = Some((state.offset, state.len as usize)); - if state.offset == 0 { - for input in &mut inputs { - let input_lp = lp_arena.take(*input); - let input_lp = self.pushdown(input_lp, Some(state), lp_arena, expr_arena)?; - lp_arena.replace(*input, input_lp); - } + Ok(lp) + } + (Union {mut inputs, mut options }, Some(state)) => { + options.slice = Some((state.offset, state.len as usize)); + if state.offset == 0 { + for input in &mut inputs { + let input_lp = lp_arena.take(*input); + let input_lp = self.pushdown(input_lp, Some(state), lp_arena, expr_arena)?; + lp_arena.replace(*input, input_lp); } - Ok(Union {inputs, options}) - }, - (Join { + } + Ok(Union {inputs, options}) + }, + (Join { + input_left, + input_right, + schema, + left_on, + right_on, + mut options + }, Some(state)) if !self.streaming => { + // first restart optimization in both inputs and get the updated LP + let lp_left = lp_arena.take(input_left); + let lp_left = self.pushdown(lp_left, None, lp_arena, expr_arena)?; + let input_left = lp_arena.add(lp_left); + + let lp_right = lp_arena.take(input_right); + let lp_right = self.pushdown(lp_right, None, lp_arena, expr_arena)?; + let input_right = lp_arena.add(lp_right); + + // then assign the slice state to the join operation + + let mut_options = Arc::make_mut(&mut options); + mut_options.args.slice = Some((state.offset, state.len as usize)); + + Ok(Join { input_left, input_right, schema, left_on, right_on, - mut options - }, Some(state)) if !self.streaming => { - // first restart optimization in both inputs and get the updated LP - let lp_left = lp_arena.take(input_left); - let lp_left = self.pushdown(lp_left, None, lp_arena, expr_arena)?; - let input_left = lp_arena.add(lp_left); - - let lp_right = lp_arena.take(input_right); - let lp_right = self.pushdown(lp_right, None, lp_arena, expr_arena)?; - let input_right = lp_arena.add(lp_right); - - // then assign the slice state to the join operation - - let mut_options = Arc::make_mut(&mut options); - mut_options.args.slice = Some((state.offset, state.len as usize)); - - Ok(Join { - input_left, - input_right, - schema, - left_on, - right_on, - options - }) - } - (Aggregate { input, keys, aggs, schema, apply, maintain_order, mut options }, Some(state)) => { - // first restart optimization in inputs and get the updated LP - let input_lp = lp_arena.take(input); - let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; - let input= lp_arena.add(input_lp); + options + }) + } + (Aggregate { input, keys, aggs, schema, apply, maintain_order, mut options }, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); - let mut_options= Arc::make_mut(&mut options); - mut_options.slice = Some((state.offset, state.len as 
usize)); + let mut_options= Arc::make_mut(&mut options); + mut_options.slice = Some((state.offset, state.len as usize)); - Ok(Aggregate { - input, - keys, - aggs, - schema, - apply, - maintain_order, - options - }) - } - (Distinct {input, mut options}, Some(state)) => { - // first restart optimization in inputs and get the updated LP - let input_lp = lp_arena.take(input); - let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; - let input= lp_arena.add(input_lp); - options.slice = Some((state.offset, state.len as usize)); - Ok(Distinct { - input, - options, - }) - } - (Sort {input, by_column, mut args}, Some(state)) => { - // first restart optimization in inputs and get the updated LP - let input_lp = lp_arena.take(input); - let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; - let input= lp_arena.add(input_lp); + Ok(Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options + }) + } + (Distinct {input, mut options}, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); + options.slice = Some((state.offset, state.len as usize)); + Ok(Distinct { + input, + options, + }) + } + (Sort {input, by_column, mut args}, Some(state)) => { + // first restart optimization in inputs and get the updated LP + let input_lp = lp_arena.take(input); + let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; + let input= lp_arena.add(input_lp); - args.slice = Some((state.offset, state.len as usize)); - Ok(Sort { - input, - by_column, - args - }) - } - (Slice { + args.slice = Some((state.offset, state.len as usize)); + Ok(Sort { input, - offset, - len - }, Some(previous_state)) => { - let alp = lp_arena.take(input); - let state = Some(if previous_state.offset == offset { - State { - offset, - len: std::cmp::min(len, previous_state.len) - } - } else { - State { - offset, - len - } - }); - let lp = self.pushdown(alp, state, lp_arena, expr_arena)?; - let input = lp_arena.add(lp); - Ok(Slice { - input, - offset: previous_state.offset, - len: previous_state.len - }) - } - (Slice { + by_column, + args + }) + } + (Slice { + input, + offset, + len + }, Some(previous_state)) => { + let alp = lp_arena.take(input); + let state = Some(if previous_state.offset == offset { + State { + offset, + len: std::cmp::min(len, previous_state.len) + } + } else { + State { + offset, + len + } + }); + let lp = self.pushdown(alp, state, lp_arena, expr_arena)?; + let input = lp_arena.add(lp); + Ok(Slice { input, + offset: previous_state.offset, + len: previous_state.len + }) + } + (Slice { + input, + offset, + len + }, None) => { + let alp = lp_arena.take(input); + let state = Some(State { offset, len - }, None) => { - let alp = lp_arena.take(input); - let state = Some(State { - offset, - len - }); - self.pushdown(alp, state, lp_arena, expr_arena) - } - // [Do not pushdown] boundary - // here we do not pushdown. - // we reset the state and then start the optimization again - m @ (Selection { .. 
}, _) - // other blocking nodes - | m @ (DataFrameScan {..}, _) - | m @ (Sort {..}, _) - | m @ (MapFunction {function: FunctionNode::Explode {..}, ..}, _) - | m @ (MapFunction {function: FunctionNode::Melt {..}, ..}, _) - | m @ (Cache {..}, _) - | m @ (Distinct {..}, _) - | m @ (Aggregate{..},_) - // blocking in streaming - | m @ (Join{..},_) - => { - let (lp, state) = m; - self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) - } - // [Pushdown] - (MapFunction {input, function}, _) if function.allow_predicate_pd() => { - let lp = MapFunction {input, function}; - self.pushdown_and_continue(lp, state, lp_arena, expr_arena) - }, - // [NO Pushdown] - m @ (MapFunction {..}, _) => { - let (lp, state) = m; - self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) - } - // [Pushdown] - // these nodes will be pushed down. - // State is None, we can continue - m @(Projection{..}, None) - => { - let (lp, state) = m; + }); + self.pushdown(alp, state, lp_arena, expr_arena) + } + // [Do not pushdown] boundary + // here we do not pushdown. + // we reset the state and then start the optimization again + m @ (Selection { .. }, _) + // other blocking nodes + | m @ (DataFrameScan {..}, _) + | m @ (Sort {..}, _) + | m @ (MapFunction {function: FunctionNode::Explode {..}, ..}, _) + | m @ (MapFunction {function: FunctionNode::Melt {..}, ..}, _) + | m @ (Cache {..}, _) + | m @ (Distinct {..}, _) + | m @ (Aggregate{..},_) + // blocking in streaming + | m @ (Join{..},_) + => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + } + // [Pushdown] + (MapFunction {input, function}, _) if function.allow_predicate_pd() => { + let lp = MapFunction {input, function}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + }, + // [NO Pushdown] + m @ (MapFunction {..}, _) => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + } + // [Pushdown] + // these nodes will be pushed down. + // State is None, we can continue + m @(Projection{..}, None) + => { + let (lp, state) = m; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + // there is state, inspect the projection to determine how to deal with it + (Projection {input, expr, schema, options}, Some(_)) => { + if can_pushdown_slice_past_projections(&expr, expr_arena).1 { + let lp = Projection {input, expr, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } - // there is state, inspect the projection to determine how to deal with it - (Projection {input, expr, schema, options}, Some(_)) => { - if can_pushdown_slice_past_projections(&expr, expr_arena).1 { - let lp = Projection {input, expr, schema, options}; - self.pushdown_and_continue(lp, state, lp_arena, expr_arena) - } - // don't push down slice, but restart optimization - else { - let lp = Projection {input, expr, schema, options}; - self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) - } + // don't push down slice, but restart optimization + else { + let lp = Projection {input, expr, schema, options}; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) } - (HStack {input, exprs, schema, options}, _) => { - let check = can_pushdown_slice_past_projections(&exprs, expr_arena); + } + (HStack {input, exprs, schema, options}, _) => { + let check = can_pushdown_slice_past_projections(&exprs, expr_arena); - if ( - // If the schema length is greater then an input column is being projected, so - // the exprs in with_columns do not need to have an input column name. 
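// Illustrative sketch of the boolean decided below. A slice may commute with
// `with_columns` when the expressions are purely elementwise: slicing the input
// first and computing after yields the same rows as computing first and slicing
// after. The parameter names paraphrase `check.0`/`check.1` and are assumptions,
// not the exact semantics of `can_pushdown_slice_past_projections`.
fn can_push_slice_past_hstack(
    schema_len: usize,               // width of the output schema
    exprs_len: usize,                // number of with_columns expressions
    elementwise_and_has_input: bool, // roughly check.0 in the hunk
    elementwise: bool,               // roughly check.1 in the hunk
) -> bool {
    // schema_len > exprs_len means at least one input column passes through
    // untouched, so the output row count is pinned to the input's.
    (schema_len > exprs_len && elementwise_and_has_input) || elementwise
}

fn main() {
    // select(c).with_columns(c = c + 1): the output is exactly the expression
    // list, but everything is elementwise, so the slice can still go underneath.
    assert!(can_push_slice_past_hstack(1, 1, false, true));
    // An expression that can change row counts blocks the pushdown.
    assert!(!can_push_slice_past_hstack(3, 1, false, false));
}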
- schema.len() > exprs.len() && check.0 - ) - || check.1 // e.g. select(c).with_columns(c = c + 1) - { - let lp = HStack {input, exprs, schema, options}; - self.pushdown_and_continue(lp, state, lp_arena, expr_arena) - } - // don't push down slice, but restart optimization - else { - let lp = HStack {input, exprs, schema, options}; - self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) - } - } - (HConcat {inputs, schema, options}, _) => { - // Slice can always be pushed down for horizontal concatenation - let lp = HConcat {inputs, schema, options}; + if ( + // If the schema length is greater then an input column is being projected, so + // the exprs in with_columns do not need to have an input column name. + schema.len() > exprs.len() && check.0 + ) + || check.1 // e.g. select(c).with_columns(c = c + 1) + { + let lp = HStack {input, exprs, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } - (catch_all, state) => { - self.no_pushdown_finish_opt(catch_all, state, lp_arena) + // don't push down slice, but restart optimization + else { + let lp = HStack {input, exprs, schema, options}; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) } } - }) + (HConcat {inputs, schema, options}, _) => { + // Slice can always be pushed down for horizontal concatenation + let lp = HConcat {inputs, schema, options}; + self.pushdown_and_continue(lp, state, lp_arena, expr_arena) + } + (catch_all, state) => { + self.no_pushdown_finish_opt(catch_all, state, lp_arena) + } + } } pub fn optimize( diff --git a/crates/polars-plan/src/logical_plan/visitor/visitors.rs b/crates/polars-plan/src/logical_plan/visitor/visitors.rs index 3f3e3436f401..5c8e06877d5a 100644 --- a/crates/polars-plan/src/logical_plan/visitor/visitors.rs +++ b/crates/polars-plan/src/logical_plan/visitor/visitors.rs @@ -1,4 +1,4 @@ -use polars_utils::recursion::with_dynamic_stack; +use recursive::recursive; use super::*; @@ -13,44 +13,42 @@ pub trait TreeWalker: Sized { fn map_children(self, op: &mut dyn FnMut(Self) -> PolarsResult) -> PolarsResult; /// Walks all nodes in depth-first-order. + #[recursive] fn visit(&self, visitor: &mut dyn Visitor) -> PolarsResult { - with_dynamic_stack(|| { - match visitor.pre_visit(self)? { - VisitRecursion::Continue => {}, - // If the recursion should skip, do not apply to its children. And let the recursion continue - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - // If the recursion should stop, do not apply to its children - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - }; + match visitor.pre_visit(self)? { + VisitRecursion::Continue => {}, + // If the recursion should skip, do not apply to its children. And let the recursion continue + VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + // If the recursion should stop, do not apply to its children + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + }; - match self.apply_children(&mut |node| node.visit(visitor))? { - // let the recursion continue - VisitRecursion::Continue | VisitRecursion::Skip => {}, - // If the recursion should stop, no further post visit will be performed - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), - } + match self.apply_children(&mut |node| node.visit(visitor))? 
{ + // let the recursion continue + VisitRecursion::Continue | VisitRecursion::Skip => {}, + // If the recursion should stop, no further post visit will be performed + VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + } - visitor.post_visit(self) - }) + visitor.post_visit(self) } + #[recursive] fn rewrite(self, rewriter: &mut dyn RewritingVisitor) -> PolarsResult { - with_dynamic_stack(|| { - let mutate_this_node = match rewriter.pre_visit(&self)? { - RewriteRecursion::MutateAndStop => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::MutateAndContinue => true, - RewriteRecursion::NoMutateAndContinue => false, - }; + let mutate_this_node = match rewriter.pre_visit(&self)? { + RewriteRecursion::MutateAndStop => return rewriter.mutate(self), + RewriteRecursion::Stop => return Ok(self), + RewriteRecursion::MutateAndContinue => true, + RewriteRecursion::NoMutateAndContinue => false, + }; - let after_applied_children = self.map_children(&mut |node| node.rewrite(rewriter))?; + let after_applied_children = self.map_children(&mut |node| node.rewrite(rewriter))?; - if mutate_this_node { - rewriter.mutate(after_applied_children) - } else { - Ok(after_applied_children) - } - }) + if mutate_this_node { + rewriter.mutate(after_applied_children) + } else { + Ok(after_applied_children) + } } } diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index c1e5673f1720..575571b62985 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -14,7 +14,6 @@ pub mod hashing; pub mod idx_vec; pub mod mem; pub mod min_max; -pub mod recursion; pub mod slice; pub mod sort; pub mod sync; diff --git a/crates/polars-utils/src/recursion.rs b/crates/polars-utils/src/recursion.rs deleted file mode 100644 index a2db919dd07c..000000000000 --- a/crates/polars-utils/src/recursion.rs +++ /dev/null @@ -1,6 +0,0 @@ -const STACK_SIZE_GUARANTEE: usize = 256 * 1024; -const STACK_ALLOC_SIZE: usize = 2 * 1024 * 1024; - -pub fn with_dynamic_stack R>(f: F) -> R { - stacker::maybe_grow(STACK_SIZE_GUARANTEE, STACK_ALLOC_SIZE, f) -} From b808bc04fed17b03f62ebfe9bd288b3e9abd6e5d Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Sat, 30 Mar 2024 18:18:57 +1100 Subject: [PATCH 14/30] fix: Hash failure combining hash of two numeric columns containing equal values (#15397) --- .../polars-core/src/hashing/vector_hasher.rs | 12 +++++++---- py-polars/tests/unit/dataframe/test_df.py | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 1c7635e701b1..dc20a2baa19d 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -112,7 +112,13 @@ where .iter() .zip(&mut hashes[offset..]) .for_each(|(v, h)| { - *h = folded_multiply(random_state.hash_one(v.to_total_ord()) ^ *h, MULTIPLE); + // Inlined from ahash. This ensures we combine with the previous state. + *h = folded_multiply( + // Be careful not to xor the hash directly with the existing hash, + // it would lead to 0-hashes for 2 columns containing equal values. + random_state.hash_one(v.to_total_ord()) ^ folded_multiply(*h, MULTIPLE), + MULTIPLE, + ); }), _ => { let validity = arr.validity().unwrap(); @@ -124,9 +130,7 @@ where .for_each(|((valid, h), l)| { let lh = random_state.hash_one(l.to_total_ord()); let to_hash = [null_h, lh][valid as usize]; - - // inlined from ahash. 
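// Illustrative sketch of the failure mode this hunk fixes: if the first column
// writes its raw value hash into the row state and the next column combines
// with a plain xor, two columns holding equal values cancel out (h ^ h == 0),
// so every row lands in one hash bucket. Folding the previous state through a
// multiply before the xor breaks that symmetry. `folded_multiply` below is a
// stand-in with the same shape as ahash's helper, not the library's own code.
const MULTIPLE: u64 = 6364136223846793005;

fn folded_multiply(s: u64, by: u64) -> u64 {
    let wide = (s as u128) * (by as u128);
    (wide as u64) ^ ((wide >> 64) as u64)
}

fn hash_one(v: u64) -> u64 {
    // Any per-value hash works for the demonstration.
    folded_multiply(v ^ 0x9E37_79B9_7F4A_7C15, MULTIPLE)
}

fn main() {
    for v in [1u64, 2, 3, 12345] {
        // Column 1 initializes the row state with its value hash.
        let state = hash_one(v);

        // Old combine for an equal-valued column 2: the xor cancels to zero,
        // regardless of the value, so all rows collide.
        let old = folded_multiply(hash_one(v) ^ state, MULTIPLE);
        assert_eq!(old, 0);

        // Fixed combine: mix the previous state first, then xor; the two xor
        // inputs are no longer equal, so the state keeps distinguishing rows.
        let new = folded_multiply(hash_one(v) ^ folded_multiply(state, MULTIPLE), MULTIPLE);
        println!("v = {v:>5}: old = {old}, new = {new:#018x}");
    }
}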
This ensures we combine with the previous state - *h = folded_multiply(to_hash ^ *h, MULTIPLE); + *h = folded_multiply(to_hash ^ folded_multiply(*h, MULTIPLE), MULTIPLE); }); }, } diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 3b8123c65066..2ca9f40a88db 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1441,6 +1441,27 @@ def test_reproducible_hash_with_seeds() -> None: assert_series_equal(expected, result, check_names=False, check_exact=True) +@pytest.mark.slow() +@pytest.mark.parametrize( + "e", + [ + pl.int_range(1_000_000), + # Test code path for null_count > 0 + pl.when(pl.int_range(1_000_000) != 0).then(pl.int_range(1_000_000)), + ], +) +def test_hash_collision_multiple_columns_equal_values_15390(e: pl.Expr) -> None: + df = pl.select(e.alias("a")) + + for n_columns in (1, 2, 3, 4): + s = df.select(pl.col("a").alias(f"x{i}") for i in range(n_columns)).hash_rows() + + vc = s.sort().value_counts(sort=True) + max_bucket_size = vc["count"][0] + + assert max_bucket_size == 1 + + def test_hashing_on_python_objects() -> None: # see if we can do a group_by, drop_duplicates on a DataFrame with objects. # this requires that the hashing and aggregations are done on python objects From c39ccae42d857cd1405d2b65e4ac1c67f58d01a9 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 30 Mar 2024 09:11:34 +0100 Subject: [PATCH 15/30] perf: Use row-encoding for multiple key group by (#15392) --- .../src/chunked_array/object/mod.rs | 15 + .../ops/sort/arg_sort_multiple.rs | 10 +- .../src/chunked_array/ops/sort/mod.rs | 4 +- crates/polars-core/src/datatypes/any_value.rs | 2 + crates/polars-core/src/datatypes/dtype.rs | 15 + .../polars-core/src/frame/group_by/hashing.rs | 298 ++---------------- crates/polars-core/src/frame/group_by/mod.rs | 77 ++--- crates/polars-core/src/frame/mod.rs | 2 +- crates/polars-core/src/frame/row/mod.rs | 73 +++++ .../src/series/implementations/struct_.rs | 3 +- crates/polars-utils/src/total_ord.rs | 85 ++--- 11 files changed, 210 insertions(+), 374 deletions(-) diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index c834a61fa990..94a6c203d19a 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -33,6 +33,14 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display { fn as_any(&self) -> &dyn Any; fn to_boxed(&self) -> Box; + + fn equal(&self, other: &dyn PolarsObjectSafe) -> bool; +} + +impl PartialEq for &dyn PolarsObjectSafe { + fn eq(&self, other: &Self) -> bool { + self.equal(*other) + } } /// Values need to implement this so that they can be stored into a Series and DataFrame @@ -55,6 +63,13 @@ impl PolarsObjectSafe for T { fn to_boxed(&self) -> Box { Box::new(self.clone()) } + + fn equal(&self, other: &dyn PolarsObjectSafe) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + self == other + } } pub type ObjectValueIter<'a, T> = std::slice::Iter<'a, T>; diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 44a67af75294..4deffaab0165 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -3,7 +3,6 @@ use polars_row::{convert_columns, RowsEncoded, SortField}; use 
polars_utils::iter::EnumerateIdxTrait; use super::*; -#[cfg(feature = "dtype-struct")] use crate::utils::_split_offsets; pub(crate) fn args_validate( @@ -88,8 +87,7 @@ pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { Ok(out) } -#[cfg(feature = "dtype-struct")] -pub(crate) fn encode_rows_vertical(by: &[Series]) -> PolarsResult { +pub(crate) fn encode_rows_vertical_par_default(by: &[Series]) -> PolarsResult { let n_threads = POOL.current_num_threads(); let len = by[0].len(); let splits = _split_offsets(len, n_threads); @@ -108,6 +106,12 @@ pub(crate) fn encode_rows_vertical(by: &[Series]) -> PolarsResult PolarsResult { + let descending = vec![false; by.len()]; + let rows = _get_rows_encoded(by, &descending, false)?; + Ok(BinaryOffsetChunked::with_chunk("", rows.into_array())) +} + pub fn _get_rows_encoded( by: &[Series], descending: &[bool], diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 1610ea6b2fb4..2611227031f7 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -15,9 +15,7 @@ use rayon::prelude::*; pub use slice::*; use crate::prelude::compare_inner::TotalOrdInner; -#[cfg(feature = "dtype-struct")] -use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; -use crate::prelude::sort::arg_sort_multiple::{arg_sort_multiple_impl, args_validate}; +use crate::prelude::sort::arg_sort_multiple::*; use crate::prelude::*; use crate::series::IsSorted; use crate::utils::NoNull; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 523b7a9939d4..251db06b0044 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -946,6 +946,8 @@ impl AnyValue<'_> { // 1.2 at scale 1, and 1.20 at scale 2, are not equal. 
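// In miniature, the rule the comparison just below implements: the physical
// integer and the scale must both match, so the same logical quantity at
// different scales compares unequal. A minimal sketch, not Polars' own code:
fn decimal_eq(v_l: i128, scale_l: usize, v_r: i128, scale_r: usize) -> bool {
    v_l == v_r && scale_l == scale_r
}

fn main() {
    // 1.2 at scale 1 is stored as (12, 1); 1.20 at scale 2 as (120, 2).
    assert!(decimal_eq(12, 1, 12, 1));
    assert!(!decimal_eq(12, 1, 120, 2)); // same quantity, different representation
}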
*v_l == *v_r && *scale_l == *scale_r }, + #[cfg(feature = "object")] + (Object(l), Object(r)) => l == r, _ => false, } } diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index db1f52497ff6..3f7b58a6379a 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -241,6 +241,21 @@ impl DataType { matches!(self, DataType::Binary) } + pub fn is_object(&self) -> bool { + #[cfg(feature = "object")] + { + matches!(self, DataType::Object(_, _)) + } + #[cfg(not(feature = "object"))] + { + false + } + } + + pub fn is_null(&self) -> bool { + matches!(self, DataType::Null) + } + pub fn contains_views(&self) -> bool { use DataType::*; match self { diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index b3e85e5dacb5..418471abc388 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -1,19 +1,16 @@ -use std::hash::{BuildHasher, Hash}; +use std::hash::{BuildHasher, Hash, Hasher}; -use hashbrown::hash_map::{Entry, RawEntryMut}; -use hashbrown::HashMap; +use hashbrown::hash_map::RawEntryMut; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; -use polars_utils::iter::EnumerateIdxTrait; use polars_utils::sync::SyncPtr; use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use polars_utils::unitvec; use rayon::prelude::*; use crate::hashing::*; -use crate::prelude::compare_inner::TotalEqInner; use crate::prelude::*; -use crate::utils::{flatten, split_df}; +use crate::utils::flatten; use crate::POOL; fn get_init_size() -> usize { @@ -76,89 +73,33 @@ fn finish_group_order(mut out: Vec>, sorted: bool) -> GroupsProxy { } } -// The inner vecs should be sorted by [`IdxSize`] -// the group_by multiple keys variants suffice -// this requirements as they use an [`IdxMap`] strategy -fn finish_group_order_vecs( - mut vecs: Vec<(Vec, Vec)>, - sorted: bool, -) -> GroupsProxy { - if sorted { - if vecs.len() == 1 { - let (first, all) = vecs.pop().unwrap(); - return GroupsProxy::Idx(GroupsIdx::new(first, all, true)); - } - - let cap = vecs.iter().map(|v| v.0.len()).sum::(); - let offsets = vecs - .iter() - .scan(0_usize, |acc, v| { - let out = *acc; - *acc += v.0.len(); - Some(out) - }) - .collect::>(); - - // we write (first, all) tuple because of sorting - let mut items = Vec::with_capacity(cap); - let items_ptr = unsafe { SyncPtr::new(items.as_mut_ptr()) }; - - POOL.install(|| { - vecs.into_par_iter() - .zip(offsets) - .for_each(|((first, all), offset)| { - // pre-sort every array not needed as items are already sorted - // this is due to using an index hashmap - - unsafe { - let mut items_ptr: *mut (IdxSize, IdxVec) = items_ptr.get(); - items_ptr = items_ptr.add(offset); - - // give the compiler some info - // maybe it may elide some loop counters - assert_eq!(first.len(), all.len()); - for (i, (first, all)) in first.into_iter().zip(all).enumerate() { - std::ptr::write(items_ptr.add(i), (first, all)) - } - } - }); - }); - unsafe { - items.set_len(cap); - } - // sort again - items.sort_unstable_by_key(|g| g.0); - - let mut idx = GroupsIdx::from_iter(items); - idx.sorted = true; - GroupsProxy::Idx(idx) - } else { - // this materialization is parallel in the from impl. 
- GroupsProxy::Idx(GroupsIdx::from(vecs)) - } -} - pub(crate) fn group_by(a: impl Iterator, sorted: bool) -> GroupsProxy where - T: TotalHash + TotalEq + ToTotalOrd, - ::TotalOrdItem: Hash + Eq, + T: TotalHash + TotalEq, { let init_size = get_init_size(); - let mut hash_tbl: PlHashMap = - PlHashMap::with_capacity(init_size); + let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); + let hasher = hash_tbl.hasher().clone(); let mut cnt = 0; a.for_each(|k| { - let k = k.to_total_ord(); let idx = cnt; cnt += 1; - let entry = hash_tbl.entry(k); + + let mut state = hasher.build_hasher(); + k.tot_hash(&mut state); + let h = state.finish(); + let entry = hash_tbl.raw_entry_mut().from_hash(h, |k_| k.tot_eq(k_)); match entry { - Entry::Vacant(entry) => { + RawEntryMut::Vacant(entry) => { let tuples = unitvec![idx]; - entry.insert((idx, tuples)); + entry.insert_with_hasher(h, k, (idx, tuples), |k| { + let mut state = hasher.build_hasher(); + k.tot_hash(&mut state); + state.finish() + }); }, - Entry::Occupied(mut entry) => { + RawEntryMut::Occupied(mut entry) => { let v = entry.get_mut(); v.1.push(idx); }, @@ -318,206 +259,3 @@ where }); finish_group_order(out, sorted) } - -#[inline] -pub(crate) unsafe fn compare_keys<'a>( - keys_cmp: &'a [Box], - idx_a: usize, - idx_b: usize, -) -> bool { - for cmp in keys_cmp { - if !cmp.eq_element_unchecked(idx_a, idx_b) { - return false; - } - } - true -} - -// Differs in the because this one uses the TotalEqInner trait objects -// is faster when multiple chunks. Not yet used in join. -pub(crate) fn populate_multiple_key_hashmap2<'a, V, H, F, G>( - hash_tbl: &mut HashMap, - // row index - idx: IdxSize, - // hash - original_h: u64, - // keys of the hash table (will not be inserted, the indexes will be used) - // the keys are needed for the equality check - keys_cmp: &'a [Box], - // value to insert - vacant_fn: G, - // function that gets a mutable ref to the occupied value in the hash table - occupied_fn: F, -) where - G: Fn() -> V, - F: Fn(&mut V), - H: BuildHasher, -{ - let entry = hash_tbl - .raw_entry_mut() - // uses the idx to probe rows in the original DataFrame with keys - // to check equality to find an entry - // this does not invalidate the hashmap as this equality function is not used - // during rehashing/resize (then the keys are already known to be unique). - // Only during insertion and probing an equality function is needed - .from_hash(original_h, |idx_hash| { - // first check the hash values before we incur - // cache misses - original_h == idx_hash.hash && { - let key_idx = idx_hash.idx; - // SAFETY: - // indices in a group_by operation are always in bounds. - unsafe { compare_keys(keys_cmp, key_idx as usize, idx as usize) } - } - }); - match entry { - RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); - }, - RawEntryMut::Occupied(mut entry) => { - let (_k, v) = entry.get_key_value_mut(); - occupied_fn(v); - }, - } -} - -pub(crate) fn group_by_threaded_multiple_keys_flat( - mut keys: DataFrame, - n_partitions: usize, - sorted: bool, -) -> PolarsResult { - let dfs = split_df(&mut keys, n_partitions).unwrap(); - let (hashes, _random_state) = _df_rows_to_hashes_threaded_vertical(&dfs, None)?; - - let init_size = get_init_size(); - - // trait object to compare inner types. - let keys_cmp = keys - .iter() - .map(|s| s.into_total_eq_inner()) - .collect::>(); - - // We will create a hashtable in every thread. 
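// Illustrative, std-only stand-in for why the reworked group_by above keys its
// table through TotalHash/TotalEq instead of Hash/Eq: types like f64 can then
// be grouped even though NaN != NaN under IEEE comparison. The wrapper here
// canonicalizes NaN bit patterns, much like canonical_f64 in the hunks above.
use std::collections::HashMap;

#[derive(PartialEq, Eq, Hash)]
struct TotalF64(u64); // bit pattern of a canonicalized f64

impl TotalF64 {
    fn new(v: f64) -> Self {
        // Collapse every NaN payload onto one bit pattern so NaNs form a group.
        let v = if v.is_nan() { f64::NAN } else { v };
        TotalF64(v.to_bits())
    }
}

fn group_indices(vals: &[f64]) -> HashMap<TotalF64, Vec<u32>> {
    let mut groups: HashMap<TotalF64, Vec<u32>> = HashMap::new();
    for (idx, &v) in vals.iter().enumerate() {
        groups.entry(TotalF64::new(v)).or_default().push(idx as u32);
    }
    groups
}

fn main() {
    let vals = [1.0, f64::NAN, 1.0, f64::NAN];
    let groups = group_indices(&vals);
    assert_eq!(groups.len(), 2); // {1.0: [0, 2]} and {NaN: [1, 3]}
    assert_eq!(groups[&TotalF64::new(f64::NAN)], vec![1, 3]);
}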
- // We use the hash to partition the keys to the matching hashtable. - // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - let v = POOL.install(|| { - (0..n_partitions) - .into_par_iter() - .map(|thread_no| { - let hashes = &hashes; - - // IndexMap, the indexes are stored in flat vectors - // this ensures that order remains and iteration is fast - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(init_size, Default::default()); - let mut first_vals = Vec::with_capacity(init_size); - let mut all_vals = Vec::with_capacity(init_size); - - // put the buffers behind a pointer so we can access them from as the bchk doesn't allow - // 2 mutable borrows (this is safe as we don't alias) - // even if the vecs reallocate, we have a pointer to the stack vec, and thus always - // access the proper data. - let all_buf_ptr = &mut all_vals as *mut Vec as *const Vec; - let first_buf_ptr = &mut first_vals as *mut Vec as *const Vec; - - let mut offset = 0; - for hashes in hashes { - let len = hashes.len() as IdxSize; - - let mut idx = 0; - for hashes_chunk in hashes.data_views() { - for &h in hashes_chunk { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if thread_no == hash_to_partition(h, n_partitions) { - let row_idx = idx + offset; - populate_multiple_key_hashmap2( - &mut hash_tbl, - row_idx, - h, - &keys_cmp, - || unsafe { - let first_vals = &mut *(first_buf_ptr as *mut Vec); - let all_vals = &mut *(all_buf_ptr as *mut Vec); - let offset_idx = first_vals.len() as IdxSize; - - let tuples = unitvec![row_idx]; - all_vals.push(tuples); - first_vals.push(row_idx); - offset_idx - }, - |v| unsafe { - let all_vals = &mut *(all_buf_ptr as *mut Vec); - let offset_idx = *v; - let buf = all_vals.get_unchecked_mut(offset_idx as usize); - buf.push(row_idx) - }, - ); - } - idx += 1; - } - } - - offset += len; - } - (first_vals, all_vals) - }) - .collect::>() - }); - Ok(finish_group_order_vecs(v, sorted)) -} - -pub(crate) fn group_by_multiple_keys(keys: DataFrame, sorted: bool) -> PolarsResult { - let mut hashes = Vec::with_capacity(keys.height()); - let _ = series_to_hashes(keys.get_columns(), None, &mut hashes)?; - - let init_size = get_init_size(); - - // trait object to compare inner types. - let keys_cmp = keys - .iter() - .map(|s| s.into_total_eq_inner()) - .collect::>(); - - // IndexMap, the indexes are stored in flat vectors - // this ensures that order remains and iteration is fast - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(init_size, Default::default()); - let mut first_vals = Vec::with_capacity(init_size); - let mut all_vals = Vec::with_capacity(init_size); - - // put the buffers behind a pointer so we can access them from as the bchk doesn't allow - // 2 mutable borrows (this is safe as we don't alias) - // even if the vecs reallocate, we have a pointer to the stack vec, and thus always - // access the proper data. 
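// Illustrative sketch of the idea replacing the threaded multi-key machinery
// removed here, which re-compared the original key columns on every hash-table
// probe. The replacement (group_by_with_series further down) row-encodes all
// key columns into one binary value per row, reducing a multi-key group-by to
// a single-key group-by over bytes. A toy encoder over two columns — the real
// row encoding is order-preserving, but for grouping, injectivity suffices:
fn encode_row(a: i64, b: &str) -> Vec<u8> {
    // Fixed-width int plus a length-prefixed string is unambiguous.
    let mut out = Vec::with_capacity(8 + 4 + b.len());
    out.extend_from_slice(&a.to_be_bytes());
    out.extend_from_slice(&(b.len() as u32).to_be_bytes());
    out.extend_from_slice(b.as_bytes());
    out
}

fn main() {
    use std::collections::HashMap;

    let col_a = [1i64, 2, 1, 2];
    let col_b = ["x", "y", "x", "z"];

    // One hash plus one byte comparison per probe, instead of a per-column
    // equality check against the original key columns.
    let mut groups: HashMap<Vec<u8>, Vec<u32>> = HashMap::new();
    for i in 0..col_a.len() {
        groups.entry(encode_row(col_a[i], col_b[i])).or_default().push(i as u32);
    }

    assert_eq!(groups.len(), 3); // (1,"x"), (2,"y"), (2,"z")
    assert_eq!(groups[&encode_row(1, "x")], vec![0, 2]);
}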
- let all_buf_ptr = &mut all_vals as *mut Vec as *const Vec; - let first_buf_ptr = &mut first_vals as *mut Vec as *const Vec; - - for (row_idx, h) in hashes.into_iter().enumerate_idx() { - populate_multiple_key_hashmap2( - &mut hash_tbl, - row_idx, - h, - &keys_cmp, - || unsafe { - let first_vals = &mut *(first_buf_ptr as *mut Vec); - let all_vals = &mut *(all_buf_ptr as *mut Vec); - let offset_idx = first_vals.len() as IdxSize; - - let tuples = unitvec![row_idx]; - all_vals.push(tuples); - first_vals.push(row_idx); - offset_idx - }, - |v| unsafe { - let all_vals = &mut *(all_buf_ptr as *mut Vec); - let offset_idx = *v; - let buf = all_vals.get_unchecked_mut(offset_idx as usize); - buf.push(row_idx) - }, - ); - } - - let v = vec![(first_vals, all_vals)]; - Ok(finish_group_order_vecs(v, sorted)) -} diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 75df1e198c50..bc7e90406cf7 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -22,37 +22,9 @@ mod proxy; pub use into_groups::*; pub use proxy::*; -#[cfg(feature = "dtype-struct")] -use crate::prelude::sort::arg_sort_multiple::encode_rows_vertical; - -// This will remove the sorted flag on signed integers -fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame { - let columns = by - .iter() - .map(|s| match s.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - s.cast(&DataType::UInt32).unwrap() - }, - _ => { - if s.dtype().to_physical().is_numeric() { - let s = s.to_physical_repr(); - - if s.dtype().is_float() { - s.into_owned().into_series() - } else if s.bit_repr_is_large() { - s.bit_repr_large().into_series() - } else { - s.bit_repr_small().into_series() - } - } else { - s.clone() - } - }, - }) - .collect(); - unsafe { DataFrame::new_no_checks(columns) } -} +use crate::prelude::sort::arg_sort_multiple::{ + encode_rows_default, encode_rows_vertical_par_default, +}; impl DataFrame { pub fn group_by_with_series( @@ -82,25 +54,42 @@ impl DataFrame { } }; - let n_partitions = _set_partition_size(); - let groups = if by.len() == 1 { let series = &by[0]; series.group_tuples(multithreaded, sorted) - } else { - #[cfg(feature = "dtype-struct")] + } else if by.iter().any(|s| s.dtype().is_object()) { + #[cfg(feature = "object")] { - if by.iter().any(|s| matches!(s.dtype(), DataType::Struct(_))) { - let rows = encode_rows_vertical(&by)?; - let groups = rows.group_tuples(multithreaded, sorted)?; - return Ok(GroupBy::new(self, by, groups, None)); - } + let mut df = DataFrame::new(by.clone()).unwrap(); + let n = df.height(); + let rows = df.to_av_rows(); + let iter = (0..n).map(|i| rows.get(i)); + Ok(group_by(iter, sorted)) + } + #[cfg(not(feature = "object"))] + { + unreachable!() } - let keys_df = prepare_dataframe_unsorted(&by); - if multithreaded { - group_by_threaded_multiple_keys_flat(keys_df, n_partitions, sorted) + } else { + // Skip null dtype. + let by = by + .iter() + .filter(|s| !s.dtype().is_null()) + .cloned() + .collect::>(); + if by.is_empty() { + Ok(GroupsProxy::Slice { + groups: vec![[0, self.height() as IdxSize]], + rolling: false, + }) } else { - group_by_multiple_keys(keys_df, sorted) + let rows = if multithreaded { + encode_rows_vertical_par_default(&by) + } else { + encode_rows_default(&by) + }? 
+ .into_series(); + rows.group_tuples(multithreaded, sorted) } }; Ok(GroupBy::new(self, by, groups?, None)) diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 0d8a696d190a..83c8292918f4 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -17,7 +17,7 @@ pub mod explode; mod from; #[cfg(feature = "algorithm_group_by")] pub mod group_by; -#[cfg(feature = "rows")] +#[cfg(any(feature = "rows", feature = "object"))] pub mod row; mod top_k; mod upstream_traits; diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index 05701260416f..e9cf92ffad13 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -4,16 +4,89 @@ mod transpose; use std::borrow::Borrow; use std::fmt::Debug; +#[cfg(feature = "object")] +use std::hash::{Hash, Hasher}; use std::hint::unreachable_unchecked; use arrow::bitmap::Bitmap; pub use av_buffer::*; +#[cfg(feature = "object")] +use polars_utils::total_ord::TotalHash; use rayon::prelude::*; use crate::prelude::*; use crate::utils::{dtypes_to_schema, dtypes_to_supertype, try_get_supertype}; use crate::POOL; +#[cfg(feature = "object")] +pub(crate) struct AnyValueRows<'a> { + vals: Vec>, + width: usize, +} + +#[cfg(feature = "object")] +pub(crate) struct AnyValueRow<'a>(&'a [AnyValue<'a>]); + +#[cfg(feature = "object")] +impl<'a> AnyValueRows<'a> { + pub(crate) fn get(&'a self, i: usize) -> AnyValueRow<'a> { + let start = i * self.width; + let end = (i + 1) * self.width; + AnyValueRow(&self.vals[start..end]) + } +} + +#[cfg(feature = "object")] +impl TotalEq for AnyValueRow<'_> { + fn tot_eq(&self, other: &Self) -> bool { + let lhs = self.0; + let rhs = other.0; + + // Should only be used in that context. + debug_assert_eq!(lhs.len(), rhs.len()); + lhs.iter().zip(rhs.iter()).all(|(l, r)| l == r) + } +} + +#[cfg(feature = "object")] +impl TotalHash for AnyValueRow<'_> { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.0.iter().for_each(|av| av.hash(state)) + } +} + +impl DataFrame { + #[cfg(feature = "object")] + #[allow(clippy::wrong_self_convention)] + // Create indexable rows in a single allocation. + pub(crate) fn to_av_rows(&mut self) -> AnyValueRows<'_> { + self.as_single_chunk_par(); + let width = self.width(); + let size = width * self.height(); + let mut buf = vec![AnyValue::Null; size]; + for (col_i, s) in self.columns.iter().enumerate() { + match s.dtype() { + #[cfg(feature = "object")] + DataType::Object(_, _) => { + for row_i in 0..s.len() { + let av = s.get(row_i).unwrap(); + buf[row_i * width + col_i] = av + } + }, + _ => { + for (row_i, av) in s.iter().enumerate() { + buf[row_i * width + col_i] = av + } + }, + } + } + AnyValueRows { vals: buf, width } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Row<'a>(pub Vec>); diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 219b88f36f7d..861599cb91b6 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -272,9 +272,8 @@ impl SeriesTrait for SeriesWrap { if self.len() == 1 { return Ok(IdxCa::new_vec(self.name(), vec![0 as IdxSize])); } - // TODO! 
try row encoding let main_thread = POOL.current_thread_index().is_none(); - let groups = self.group_tuples(main_thread, false)?; + let groups = self.group_tuples(main_thread, true)?; let first = groups.take_group_firsts(); Ok(IdxCa::from_vec(self.name(), first)) } diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs index 5b9af065779f..f5d7b22a95aa 100644 --- a/crates/polars-utils/src/total_ord.rs +++ b/crates/polars-utils/src/total_ord.rs @@ -90,46 +90,46 @@ pub struct TotalOrdWrap(pub T); unsafe impl TransparentWrapper for TotalOrdWrap {} impl PartialOrd for TotalOrdWrap { - #[inline] + #[inline(always)] fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } - #[inline] + #[inline(always)] fn lt(&self, other: &Self) -> bool { self.0.tot_lt(&other.0) } - #[inline] + #[inline(always)] fn le(&self, other: &Self) -> bool { self.0.tot_le(&other.0) } - #[inline] + #[inline(always)] fn gt(&self, other: &Self) -> bool { self.0.tot_gt(&other.0) } - #[inline] + #[inline(always)] fn ge(&self, other: &Self) -> bool { self.0.tot_ge(&other.0) } } impl Ord for TotalOrdWrap { - #[inline] + #[inline(always)] fn cmp(&self, other: &Self) -> Ordering { self.0.tot_cmp(&other.0) } } impl PartialEq for TotalOrdWrap { - #[inline] + #[inline(always)] fn eq(&self, other: &Self) -> bool { self.0.tot_eq(&other.0) } - #[inline] + #[inline(always)] #[allow(clippy::partialeq_ne_impl)] fn ne(&self, other: &Self) -> bool { self.0.tot_ne(&other.0) @@ -139,7 +139,7 @@ impl PartialEq for TotalOrdWrap { impl Eq for TotalOrdWrap {} impl Hash for TotalOrdWrap { - #[inline] + #[inline(always)] fn hash(&self, state: &mut H) { self.0.tot_hash(state); } @@ -158,33 +158,33 @@ impl IsNull for TotalOrdWrap { const HAS_NULLS: bool = T::HAS_NULLS; type Inner = T::Inner; - #[inline] + #[inline(always)] fn is_null(&self) -> bool { self.0.is_null() } - #[inline] + #[inline(always)] fn unwrap_inner(self) -> Self::Inner { self.0.unwrap_inner() } } impl DirtyHash for f32 { - #[inline] + #[inline(always)] fn dirty_hash(&self) -> u64 { canonical_f32(*self).to_bits().dirty_hash() } } impl DirtyHash for f64 { - #[inline] + #[inline(always)] fn dirty_hash(&self) -> u64 { canonical_f64(*self).to_bits().dirty_hash() } } impl DirtyHash for TotalOrdWrap { - #[inline] + #[inline(always)] fn dirty_hash(&self) -> u64 { self.0.dirty_hash() } @@ -193,46 +193,46 @@ impl DirtyHash for TotalOrdWrap { macro_rules! impl_trivial_total { ($T: ty) => { impl TotalEq for $T { - #[inline] + #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { self == other } - #[inline] + #[inline(always)] fn tot_ne(&self, other: &Self) -> bool { self != other } } impl TotalOrd for $T { - #[inline] + #[inline(always)] fn tot_cmp(&self, other: &Self) -> Ordering { self.cmp(other) } - #[inline] + #[inline(always)] fn tot_lt(&self, other: &Self) -> bool { self < other } - #[inline] + #[inline(always)] fn tot_gt(&self, other: &Self) -> bool { self > other } - #[inline] + #[inline(always)] fn tot_le(&self, other: &Self) -> bool { self <= other } - #[inline] + #[inline(always)] fn tot_ge(&self, other: &Self) -> bool { self >= other } } impl TotalHash for $T { - #[inline] + #[inline(always)] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -277,7 +277,7 @@ macro_rules! impl_float_eq_ord { } impl TotalOrd for $T { - #[inline] + #[inline(always)] fn tot_cmp(&self, other: &Self) -> Ordering { if self.tot_lt(other) { Ordering::Less @@ -288,22 +288,22 @@ macro_rules! 
impl_float_eq_ord { } } - #[inline] + #[inline(always)] fn tot_lt(&self, other: &Self) -> bool { !self.tot_ge(other) } - #[inline] + #[inline(always)] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline] + #[inline(always)] fn tot_le(&self, other: &Self) -> bool { other.tot_ge(self) } - #[inline] + #[inline(always)] fn tot_ge(&self, other: &Self) -> bool { // We consider all NaNs equal, and NaN is the largest possible // value. Thus if self is NaN we always return true. Otherwise @@ -320,7 +320,7 @@ impl_float_eq_ord!(f32); impl_float_eq_ord!(f64); impl TotalHash for f32 { - #[inline] + #[inline(always)] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -330,7 +330,7 @@ impl TotalHash for f32 { } impl TotalHash for f64 { - #[inline] + #[inline(always)] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -341,7 +341,7 @@ impl TotalHash for f64 { // Blanket implementations. impl TotalEq for Option { - #[inline] + #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { match (self, other) { (None, None) => true, @@ -350,7 +350,7 @@ impl TotalEq for Option { } } - #[inline] + #[inline(always)] fn tot_ne(&self, other: &Self) -> bool { match (self, other) { (None, None) => false, @@ -361,7 +361,7 @@ impl TotalEq for Option { } impl TotalOrd for Option { - #[inline] + #[inline(always)] fn tot_cmp(&self, other: &Self) -> Ordering { match (self, other) { (None, None) => Ordering::Equal, @@ -371,7 +371,7 @@ impl TotalOrd for Option { } } - #[inline] + #[inline(always)] fn tot_lt(&self, other: &Self) -> bool { match (self, other) { (None, Some(_)) => true, @@ -380,12 +380,12 @@ impl TotalOrd for Option { } } - #[inline] + #[inline(always)] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline] + #[inline(always)] fn tot_le(&self, other: &Self) -> bool { match (self, other) { (Some(_), None) => false, @@ -394,13 +394,14 @@ impl TotalOrd for Option { } } - #[inline] + #[inline(always)] fn tot_ge(&self, other: &Self) -> bool { other.tot_le(self) } } impl TotalHash for Option { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -413,19 +414,19 @@ impl TotalHash for Option { } impl TotalEq for &T { - #[inline] + #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { (*self).tot_eq(*other) } - #[inline] + #[inline(always)] fn tot_ne(&self, other: &Self) -> bool { (*self).tot_ne(*other) } } impl TotalHash for &T { - #[inline] + #[inline(always)] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -435,12 +436,14 @@ impl TotalHash for &T { } impl TotalEq for (T, U) { + #[inline] fn tot_eq(&self, other: &Self) -> bool { self.0.tot_eq(&other.0) && self.1.tot_eq(&other.1) } } impl TotalOrd for (T, U) { + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { self.0 .tot_cmp(&other.0) @@ -449,7 +452,7 @@ impl TotalOrd for (T, U) { } impl<'a> TotalHash for BytesHash<'a> { - #[inline] + #[inline(always)] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -459,7 +462,7 @@ impl<'a> TotalHash for BytesHash<'a> { } impl<'a> TotalEq for BytesHash<'a> { - #[inline] + #[inline(always)] fn tot_eq(&self, other: &Self) -> bool { self == other } From be4169796e6974252ada6e13cf6127d85c9bdb4f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Sat, 30 Mar 2024 16:41:48 +0100 Subject: [PATCH 16/30] refactor: make dsl immutable and cheap to clone (#15394) --- .../polars-lazy/src/physical_plan/exotic.rs | 19 +- crates/polars-plan/src/dsl/arity.rs | 10 +- crates/polars-plan/src/dsl/expr.rs | 72 ++++---- .../src/dsl/functions/syntactic_sugar.rs | 2 +- 
 crates/polars-plan/src/dsl/meta.rs            |  19 +-
 crates/polars-plan/src/dsl/mod.rs             |  58 +++---
 crates/polars-plan/src/dsl/name.rs            |   4 +-
 crates/polars-plan/src/dsl/statistics.rs      |  18 +-
 .../src/logical_plan/conversion.rs            | 145 ++++++++-------
 .../polars-plan/src/logical_plan/iterator.rs  | 157 ++++++++---------
 .../optimizer/predicate_pushdown/mod.rs       |  16 +-
 .../src/logical_plan/projection.rs            | 166 +++++++-----------
 .../src/logical_plan/visitor/expr.rs          |  58 +++++-
 crates/polars-plan/src/utils.rs               |   9 +-
 crates/polars-sql/src/context.rs              |  15 +-
 crates/polars-utils/src/functions.rs          |  39 ++++
 16 files changed, 419 insertions(+), 388 deletions(-)

diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs
index 14fb5a5c3517..664cd2bfbb2d 100644
--- a/crates/polars-lazy/src/physical_plan/exotic.rs
+++ b/crates/polars-lazy/src/physical_plan/exotic.rs
@@ -4,19 +4,12 @@ use crate::physical_plan::planner::create_physical_expr;
 use crate::prelude::*;

 #[cfg(feature = "pivot")]
-pub(crate) fn prepare_eval_expr(mut expr: Expr) -> Expr {
-    expr.mutate().apply(|e| match e {
-        Expr::Column(name) => {
-            *name = Arc::from("");
-            true
-        },
-        Expr::Nth(_) => {
-            *e = Expr::Column(Arc::from(""));
-            true
-        },
-        _ => true,
-    });
-    expr
+pub(crate) fn prepare_eval_expr(expr: Expr) -> Expr {
+    expr.map_expr(|e| match e {
+        Expr::Column(_) => Expr::Column(Arc::from("")),
+        Expr::Nth(_) => Expr::Column(Arc::from("")),
+        e => e,
+    })
 }

 pub(crate) fn prepare_expression_for_context(
diff --git a/crates/polars-plan/src/dsl/arity.rs b/crates/polars-plan/src/dsl/arity.rs
index 05ff22df52b0..9883936f6c10 100644
--- a/crates/polars-plan/src/dsl/arity.rs
+++ b/crates/polars-plan/src/dsl/arity.rs
@@ -139,17 +139,17 @@ pub fn when<E: Into<Expr>>(condition: E) -> When {

 pub fn ternary_expr(predicate: Expr, truthy: Expr, falsy: Expr) -> Expr {
     Expr::Ternary {
-        predicate: Box::new(predicate),
-        truthy: Box::new(truthy),
-        falsy: Box::new(falsy),
+        predicate: Arc::new(predicate),
+        truthy: Arc::new(truthy),
+        falsy: Arc::new(falsy),
     }
 }

 /// Compute `op(l, r)` (or equivalently `l op r`). `l` and `r` must have types compatible with the Operator.
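// The new `map_expr` style shown in `prepare_eval_expr` generalizes: a
// rewrite is a pure `Expr -> Expr` function instead of an in-place mutation.
// A hedged usage sketch (assumes the `map_expr` helper introduced later in
// this patch; `strip_column_names` is an illustrative name, not Polars API):
fn strip_column_names(expr: Expr) -> Expr {
    expr.map_expr(|e| match e {
        // Rebuild column nodes with an empty name; all other nodes pass through.
        Expr::Column(_) => Expr::Column(Arc::from("")),
        e => e,
    })
}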
pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { Expr::BinaryExpr { - left: Box::new(l), + left: Arc::new(l), op, - right: Box::new(r), + right: Arc::new(r), } } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index cb88fc45be55..5e8a31dd65f2 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -12,30 +12,30 @@ use crate::prelude::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AggExpr { Min { - input: Box, + input: Arc, propagate_nans: bool, }, Max { - input: Box, + input: Arc, propagate_nans: bool, }, - Median(Box), - NUnique(Box), - First(Box), - Last(Box), - Mean(Box), - Implode(Box), + Median(Arc), + NUnique(Arc), + First(Arc), + Last(Arc), + Mean(Arc), + Implode(Arc), // include_nulls - Count(Box, bool), + Count(Arc, bool), Quantile { - expr: Box, - quantile: Box, + expr: Arc, + quantile: Arc, interpol: QuantileInterpolOptions, }, - Sum(Box), - AggGroups(Box), - Std(Box, u8), - Var(Box, u8), + Sum(Arc), + AggGroups(Arc), + Std(Arc, u8), + Var(Arc, u8), } impl AsRef for AggExpr { @@ -67,32 +67,32 @@ impl AsRef for AggExpr { #[must_use] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Expr { - Alias(Box, Arc), + Alias(Arc, Arc), Column(Arc), Columns(Vec), DtypeColumn(Vec), Literal(LiteralValue), BinaryExpr { - left: Box, + left: Arc, op: Operator, - right: Box, + right: Arc, }, Cast { - expr: Box, + expr: Arc, data_type: DataType, strict: bool, }, Sort { - expr: Box, + expr: Arc, options: SortOptions, }, Gather { - expr: Box, - idx: Box, + expr: Arc, + idx: Arc, returns_scalar: bool, }, SortBy { - expr: Box, + expr: Arc, by: Vec, descending: Vec, }, @@ -100,9 +100,9 @@ pub enum Expr { /// A ternary operation /// if true then "foo" else "bar" Ternary { - predicate: Box, - truthy: Box, - falsy: Box, + predicate: Arc, + truthy: Arc, + falsy: Arc, }, Function { /// function arguments @@ -111,29 +111,29 @@ pub enum Expr { function: FunctionExpr, options: FunctionOptions, }, - Explode(Box), + Explode(Arc), Filter { - input: Box, - by: Box, + input: Arc, + by: Arc, }, /// See postgres window functions Window { /// Also has the input. i.e. avg("foo") - function: Box, + function: Arc, partition_by: Vec, options: WindowType, }, Wildcard, Slice { - input: Box, + input: Arc, /// length is not yet known so we accept negative offsets - offset: Box, - length: Box, + offset: Arc, + length: Arc, }, /// Can be used in a select statement to exclude a column from selection - Exclude(Box, Vec), + Exclude(Arc, Vec), /// Set root name as Alias - KeepName(Box), + KeepName(Arc), Len, /// Take the nth column in the `DataFrame` Nth(i64), @@ -141,7 +141,7 @@ pub enum Expr { #[cfg_attr(feature = "serde", serde(skip))] RenameAlias { function: SpecialEq>, - expr: Box, + expr: Arc, }, AnonymousFunction { /// function arguments diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index df778ee60ee6..5315709da4cf 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -57,7 +57,7 @@ pub fn is_not_null(expr: Expr) -> Expr { /// nominal type of the column. 
pub fn cast(expr: Expr, data_type: DataType) -> Expr { Expr::Cast { - expr: Box::new(expr), + expr: Arc::new(expr), data_type, strict: false, } diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index 28a554007a50..ac753024b8ce 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -41,22 +41,13 @@ impl MetaNameSpace { } /// Undo any renaming operation like `alias`, `keep_name`. - pub fn undo_aliases(mut self) -> Expr { - self.0.mutate().apply(|e| match e { + pub fn undo_aliases(self) -> Expr { + self.0.map_expr(|e| match e { Expr::Alias(input, _) | Expr::KeepName(input) - | Expr::RenameAlias { expr: input, .. } => { - // remove this node - *e = *input.clone(); - - // continue iteration - true - }, - // continue iteration - _ => true, - }); - - self.0 + | Expr::RenameAlias { expr: input, .. } => Arc::unwrap_or_clone(input), + e => e, + }) } /// Indicate if this expression expands to multiple expressions. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 872153db3961..0d5c7f5025b6 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -166,7 +166,7 @@ impl Expr { /// Rename Column. pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), ColumnName::from(name)) + Expr::Alias(Arc::new(self), ColumnName::from(name)) } /// Run is_null operation on `Expr`. @@ -193,29 +193,29 @@ impl Expr { /// Get the number of unique values in the groups. pub fn n_unique(self) -> Self { - AggExpr::NUnique(Box::new(self)).into() + AggExpr::NUnique(Arc::new(self)).into() } /// Get the first value in the group. pub fn first(self) -> Self { - AggExpr::First(Box::new(self)).into() + AggExpr::First(Arc::new(self)).into() } /// Get the last value in the group. pub fn last(self) -> Self { - AggExpr::Last(Box::new(self)).into() + AggExpr::Last(Arc::new(self)).into() } /// Aggregate the group to a Series. pub fn implode(self) -> Self { - AggExpr::Implode(Box::new(self)).into() + AggExpr::Implode(Arc::new(self)).into() } /// Compute the quantile per group. pub fn quantile(self, quantile: Expr, interpol: QuantileInterpolOptions) -> Self { AggExpr::Quantile { - expr: Box::new(self), - quantile: Box::new(quantile), + expr: Arc::new(self), + quantile: Arc::new(quantile), interpol, } .into() @@ -223,7 +223,7 @@ impl Expr { /// Get the group indexes of the group by operation. pub fn agg_groups(self) -> Self { - AggExpr::AggGroups(Box::new(self)).into() + AggExpr::AggGroups(Arc::new(self)).into() } /// Alias for `explode`. @@ -233,16 +233,16 @@ impl Expr { /// Explode the String/List column. pub fn explode(self) -> Self { - Expr::Explode(Box::new(self)) + Expr::Explode(Arc::new(self)) } /// Slice the Series. /// `offset` may be negative. pub fn slice, F: Into>(self, offset: E, length: F) -> Self { Expr::Slice { - input: Box::new(self), - offset: Box::new(offset.into()), - length: Box::new(length.into()), + input: Arc::new(self), + offset: Arc::new(offset.into()), + length: Arc::new(length.into()), } } @@ -375,7 +375,7 @@ impl Expr { /// Throws an error if conversion had overflows. pub fn strict_cast(self, data_type: DataType) -> Self { Expr::Cast { - expr: Box::new(self), + expr: Arc::new(self), data_type, strict: true, } @@ -384,7 +384,7 @@ impl Expr { /// Cast expression to another data type. 
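// `Arc::unwrap_or_clone`, used by `undo_aliases` above, is what keeps these
// rewrites cheap: a uniquely-owned node is moved out for free and only a
// shared node pays for a deep clone. A small standalone illustration of the
// standard-library behaviour:
fn demo_unwrap_or_clone() {
    use std::sync::Arc;
    let unique = Arc::new(String::from("a"));
    // Refcount 1: the String is moved out, no clone happens.
    let moved: String = Arc::unwrap_or_clone(unique);

    let shared = Arc::new(String::from("b"));
    let _keep_alive = Arc::clone(&shared);
    // Refcount 2: falls back to cloning the inner value.
    let cloned: String = Arc::unwrap_or_clone(shared);
    assert_eq!((moved.as_str(), cloned.as_str()), ("a", "b"));
}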
pub fn cast(self, data_type: DataType) -> Self { Expr::Cast { - expr: Box::new(self), + expr: Arc::new(self), data_type, strict: false, } @@ -393,8 +393,8 @@ impl Expr { /// Take the values by idx. pub fn gather>(self, idx: E) -> Self { Expr::Gather { - expr: Box::new(self), - idx: Box::new(idx.into()), + expr: Arc::new(self), + idx: Arc::new(idx.into()), returns_scalar: false, } } @@ -402,8 +402,8 @@ impl Expr { /// Take the values by a single index. pub fn get>(self, idx: E) -> Self { Expr::Gather { - expr: Box::new(self), - idx: Box::new(idx.into()), + expr: Arc::new(self), + idx: Arc::new(idx.into()), returns_scalar: true, } } @@ -411,7 +411,7 @@ impl Expr { /// Sort in increasing order. See [the eager implementation](Series::sort). pub fn sort(self, descending: bool) -> Self { Expr::Sort { - expr: Box::new(self), + expr: Arc::new(self), options: SortOptions { descending, ..Default::default() @@ -422,7 +422,7 @@ impl Expr { /// Sort with given options. pub fn sort_with(self, options: SortOptions) -> Self { Expr::Sort { - expr: Box::new(self), + expr: Arc::new(self), options, } } @@ -903,7 +903,7 @@ impl Expr { .map(|e| e.clone().into()) .collect(); Expr::Window { - function: Box::new(self), + function: Arc::new(self), partition_by, options: options.into(), } @@ -915,7 +915,7 @@ impl Expr { // not ignore it. let index_col = col(options.index_column.as_str()); Expr::Window { - function: Box::new(self), + function: Arc::new(self), partition_by: vec![index_col], options: WindowType::Rolling(options), } @@ -961,11 +961,11 @@ impl Expr { /// or /// Get counts of the group by operation. pub fn count(self) -> Self { - AggExpr::Count(Box::new(self), false).into() + AggExpr::Count(Arc::new(self), false).into() } pub fn len(self) -> Self { - AggExpr::Count(Box::new(self), true).into() + AggExpr::Count(Arc::new(self), true).into() } /// Get a mask of duplicated values. @@ -1037,8 +1037,8 @@ impl Expr { panic!("filter '*' not allowed, use LazyFrame::filter") }; Expr::Filter { - input: Box::new(self), - by: Box::new(predicate.into()), + input: Arc::new(self), + by: Arc::new(predicate.into()), } } @@ -1081,7 +1081,7 @@ impl Expr { let by = by.as_ref().iter().map(|e| e.clone().into()).collect(); let descending = descending.as_ref().to_vec(); Expr::SortBy { - expr: Box::new(self), + expr: Arc::new(self), by, descending, } @@ -1137,7 +1137,7 @@ impl Expr { .into_iter() .map(|s| Excluded::Name(ColumnName::from(s))) .collect(); - Expr::Exclude(Box::new(self), v) + Expr::Exclude(Arc::new(self), v) } pub fn exclude_dtype>(self, dtypes: D) -> Expr { @@ -1146,7 +1146,7 @@ impl Expr { .iter() .map(|dt| Excluded::Dtype(dt.clone())) .collect(); - Expr::Exclude(Box::new(self), v) + Expr::Exclude(Arc::new(self), v) } #[cfg(feature = "interpolate")] diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs index 61fe7951e741..def56c87e87e 100644 --- a/crates/polars-plan/src/dsl/name.rs +++ b/crates/polars-plan/src/dsl/name.rs @@ -21,7 +21,7 @@ impl ExprNameNameSpace { /// } /// ``` pub fn keep(self) -> Expr { - Expr::KeepName(Box::new(self.0)) + Expr::KeepName(Arc::new(self.0)) } /// Define an alias by mapping a function over the original root column name. 
@@ -31,7 +31,7 @@ impl ExprNameNameSpace { { let function = SpecialEq::new(Arc::new(function) as Arc); Expr::RenameAlias { - expr: Box::new(self.0), + expr: Arc::new(self.0), function, } } diff --git a/crates/polars-plan/src/dsl/statistics.rs b/crates/polars-plan/src/dsl/statistics.rs index 20a63d1e2bf1..6220f6b88b58 100644 --- a/crates/polars-plan/src/dsl/statistics.rs +++ b/crates/polars-plan/src/dsl/statistics.rs @@ -3,18 +3,18 @@ use super::*; impl Expr { /// Standard deviation of the values of the Series. pub fn std(self, ddof: u8) -> Self { - AggExpr::Std(Box::new(self), ddof).into() + AggExpr::Std(Arc::new(self), ddof).into() } /// Variance of the values of the Series. pub fn var(self, ddof: u8) -> Self { - AggExpr::Var(Box::new(self), ddof).into() + AggExpr::Var(Arc::new(self), ddof).into() } /// Reduce groups to minimal value. pub fn min(self) -> Self { AggExpr::Min { - input: Box::new(self), + input: Arc::new(self), propagate_nans: false, } .into() @@ -23,7 +23,7 @@ impl Expr { /// Reduce groups to maximum value. pub fn max(self) -> Self { AggExpr::Max { - input: Box::new(self), + input: Arc::new(self), propagate_nans: false, } .into() @@ -32,7 +32,7 @@ impl Expr { /// Reduce groups to minimal value. pub fn nan_min(self) -> Self { AggExpr::Min { - input: Box::new(self), + input: Arc::new(self), propagate_nans: true, } .into() @@ -41,7 +41,7 @@ impl Expr { /// Reduce groups to maximum value. pub fn nan_max(self) -> Self { AggExpr::Max { - input: Box::new(self), + input: Arc::new(self), propagate_nans: true, } .into() @@ -49,17 +49,17 @@ impl Expr { /// Reduce groups to the mean value. pub fn mean(self) -> Self { - AggExpr::Mean(Box::new(self)).into() + AggExpr::Mean(Arc::new(self)).into() } /// Reduce groups to the median value. pub fn median(self) -> Self { - AggExpr::Median(Box::new(self)).into() + AggExpr::Median(Arc::new(self)).into() } /// Reduce groups to the sum of all the values. pub fn sum(self) -> Self { - AggExpr::Sum(Box::new(self)).into() + AggExpr::Sum(Arc::new(self)).into() } /// Compute the histogram of a dataset. diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 1cfb23b79bc4..3c1f6f65ab3a 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -66,17 +66,18 @@ fn to_aexprs(input: Vec, arena: &mut Arena, state: &mut ConversionS /// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. 
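// The arena mentioned above, in miniature: expression nodes live in one flat
// `Vec` and reference each other by index (`Node`), which avoids pointer
// chasing and makes sub-trees trivially copyable. A hedged generic sketch,
// not the actual `Arena`/`Node` types:
struct MiniArena<T> {
    items: Vec<T>,
}

#[derive(Copy, Clone)]
struct MiniNode(usize);

impl<T> MiniArena<T> {
    fn new() -> Self {
        MiniArena { items: Vec::new() }
    }
    // Adding returns a lightweight index instead of a pointer.
    fn add(&mut self, item: T) -> MiniNode {
        self.items.push(item);
        MiniNode(self.items.len() - 1)
    }
    fn get(&self, node: MiniNode) -> &T {
        &self.items[node.0]
    }
}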
#[recursive] fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionState) -> Node { + let owned = Arc::unwrap_or_clone; let v = match expr { - Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(*expr, arena, state)), + Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)), Expr::Alias(e, name) => { if state.prune_alias { if state.output_name.is_none() && !state.ignore_alias { state.output_name = OutputName::Alias(name); } - to_aexpr_impl(*e, arena, state); + to_aexpr_impl(owned(e), arena, state); arena.pop().unwrap() } else { - AExpr::Alias(to_aexpr_impl(*e, arena, state), name) + AExpr::Alias(to_aexpr_impl(owned(e), arena, state), name) } }, Expr::Literal(lv) => { @@ -92,8 +93,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta AExpr::Column(name) }, Expr::BinaryExpr { left, op, right } => { - let l = to_aexpr_impl(*left, arena, state); - let r = to_aexpr_impl(*right, arena, state); + let l = to_aexpr_impl(owned(left), arena, state); + let r = to_aexpr_impl(owned(right), arena, state); AExpr::BinaryExpr { left: l, op, @@ -105,7 +106,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta data_type, strict, } => AExpr::Cast { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), data_type, strict, }, @@ -114,12 +115,12 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta idx, returns_scalar, } => AExpr::Gather { - expr: to_aexpr_impl(*expr, arena, state), - idx: to_aexpr_impl(*idx, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), + idx: to_aexpr_impl(owned(idx), arena, state), returns_scalar, }, Expr::Sort { expr, options } => AExpr::Sort { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), options, }, Expr::SortBy { @@ -127,7 +128,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta by, descending, } => AExpr::SortBy { - expr: to_aexpr_impl(*expr, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), by: by .into_iter() .map(|e| to_aexpr_impl(e, arena, state)) @@ -135,8 +136,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta descending, }, Expr::Filter { input, by } => AExpr::Filter { - input: to_aexpr_impl(*input, arena, state), - by: to_aexpr_impl(*by, arena, state), + input: to_aexpr_impl(owned(input), arena, state), + by: to_aexpr_impl(owned(by), arena, state), }, Expr::Agg(agg) => { let a_agg = match agg { @@ -144,38 +145,48 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta input, propagate_nans, } => AAggExpr::Min { - input: to_aexpr_impl(*input, arena, state), + input: to_aexpr_impl(owned(input), arena, state), propagate_nans, }, AggExpr::Max { input, propagate_nans, } => AAggExpr::Max { - input: to_aexpr_impl(*input, arena, state), + input: to_aexpr_impl(owned(input), arena, state), propagate_nans, }, - AggExpr::Median(expr) => AAggExpr::Median(to_aexpr_impl(*expr, arena, state)), - AggExpr::NUnique(expr) => AAggExpr::NUnique(to_aexpr_impl(*expr, arena, state)), - AggExpr::First(expr) => AAggExpr::First(to_aexpr_impl(*expr, arena, state)), - AggExpr::Last(expr) => AAggExpr::Last(to_aexpr_impl(*expr, arena, state)), - AggExpr::Mean(expr) => AAggExpr::Mean(to_aexpr_impl(*expr, arena, state)), - AggExpr::Implode(expr) => AAggExpr::Implode(to_aexpr_impl(*expr, arena, state)), + AggExpr::Median(expr) => AAggExpr::Median(to_aexpr_impl(owned(expr), arena, state)), + 
AggExpr::NUnique(expr) => { + AAggExpr::NUnique(to_aexpr_impl(owned(expr), arena, state)) + }, + AggExpr::First(expr) => AAggExpr::First(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Last(expr) => AAggExpr::Last(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Mean(expr) => AAggExpr::Mean(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Implode(expr) => { + AAggExpr::Implode(to_aexpr_impl(owned(expr), arena, state)) + }, AggExpr::Count(expr, include_nulls) => { - AAggExpr::Count(to_aexpr_impl(*expr, arena, state), include_nulls) + AAggExpr::Count(to_aexpr_impl(owned(expr), arena, state), include_nulls) }, AggExpr::Quantile { expr, quantile, interpol, } => AAggExpr::Quantile { - expr: to_aexpr_impl(*expr, arena, state), - quantile: to_aexpr_impl(*quantile, arena, state), + expr: to_aexpr_impl(owned(expr), arena, state), + quantile: to_aexpr_impl(owned(quantile), arena, state), interpol, }, - AggExpr::Sum(expr) => AAggExpr::Sum(to_aexpr_impl(*expr, arena, state)), - AggExpr::Std(expr, ddof) => AAggExpr::Std(to_aexpr_impl(*expr, arena, state), ddof), - AggExpr::Var(expr, ddof) => AAggExpr::Var(to_aexpr_impl(*expr, arena, state), ddof), - AggExpr::AggGroups(expr) => AAggExpr::AggGroups(to_aexpr_impl(*expr, arena, state)), + AggExpr::Sum(expr) => AAggExpr::Sum(to_aexpr_impl(owned(expr), arena, state)), + AggExpr::Std(expr, ddof) => { + AAggExpr::Std(to_aexpr_impl(owned(expr), arena, state), ddof) + }, + AggExpr::Var(expr, ddof) => { + AAggExpr::Var(to_aexpr_impl(owned(expr), arena, state), ddof) + }, + AggExpr::AggGroups(expr) => { + AAggExpr::AggGroups(to_aexpr_impl(owned(expr), arena, state)) + }, }; AExpr::Agg(a_agg) }, @@ -185,9 +196,9 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta falsy, } => { // Truthy must be resolved first to get the lhs name first set. 
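// Concretely: for `when(pred).then(col("a")).otherwise(col("b"))` the output
// column is named "a" because the truthy branch is resolved first. A hedged
// sketch of that naming rule (illustrative only, not the real resolver):
fn ternary_output_name(truthy: Option<&str>, falsy: Option<&str>) -> Option<String> {
    // The truthy (lhs) name wins whenever it is present.
    truthy.or(falsy).map(str::to_owned)
}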
- let t = to_aexpr_impl(*truthy, arena, state); - let p = to_aexpr_impl(*predicate, arena, state); - let f = to_aexpr_impl(*falsy, arena, state); + let t = to_aexpr_impl(owned(truthy), arena, state); + let p = to_aexpr_impl(owned(predicate), arena, state); + let f = to_aexpr_impl(owned(falsy), arena, state); AExpr::Ternary { predicate: p, truthy: t, @@ -228,7 +239,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta partition_by, options, } => AExpr::Window { - function: to_aexpr_impl(*function, arena, state), + function: to_aexpr_impl(owned(function), arena, state), partition_by: to_aexprs(partition_by, arena, state), options, }, @@ -237,9 +248,9 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta offset, length, } => AExpr::Slice { - input: to_aexpr_impl(*input, arena, state), - offset: to_aexpr_impl(*offset, arena, state), - length: to_aexpr_impl(*length, arena, state), + input: to_aexpr_impl(owned(input), arena, state), + offset: to_aexpr_impl(owned(offset), arena, state), + length: to_aexpr_impl(owned(length), arena, state), }, Expr::Len => { if state.output_name.is_none() { @@ -482,10 +493,10 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = expr_arena.get(node).clone(); match expr { - AExpr::Explode(node) => Expr::Explode(Box::new(node_to_expr(node, expr_arena))), + AExpr::Explode(node) => Expr::Explode(Arc::new(node_to_expr(node, expr_arena))), AExpr::Alias(expr, name) => { let exp = node_to_expr(expr, expr_arena); - Expr::Alias(Box::new(exp), name) + Expr::Alias(Arc::new(exp), name) }, AExpr::Column(a) => Expr::Column(a), AExpr::Literal(s) => Expr::Literal(s), @@ -493,9 +504,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let l = node_to_expr(left, expr_arena); let r = node_to_expr(right, expr_arena); Expr::BinaryExpr { - left: Box::new(l), + left: Arc::new(l), op, - right: Box::new(r), + right: Arc::new(r), } }, AExpr::Cast { @@ -505,7 +516,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = node_to_expr(expr, expr_arena); Expr::Cast { - expr: Box::new(exp), + expr: Arc::new(exp), data_type, strict, } @@ -513,7 +524,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AExpr::Sort { expr, options } => { let exp = node_to_expr(expr, expr_arena); Expr::Sort { - expr: Box::new(exp), + expr: Arc::new(exp), options, } }, @@ -525,8 +536,8 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = node_to_expr(expr, expr_arena); let idx = node_to_expr(idx, expr_arena); Expr::Gather { - expr: Box::new(expr), - idx: Box::new(idx), + expr: Arc::new(expr), + idx: Arc::new(idx), returns_scalar, } }, @@ -541,7 +552,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { .map(|node| node_to_expr(*node, expr_arena)) .collect(); Expr::SortBy { - expr: Box::new(expr), + expr: Arc::new(expr), by, descending, } @@ -550,8 +561,8 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let input = node_to_expr(input, expr_arena); let by = node_to_expr(by, expr_arena); Expr::Filter { - input: Box::new(input), - by: Box::new(by), + input: Arc::new(input), + by: Arc::new(by), } }, AExpr::Agg(agg) => match agg { @@ -561,7 +572,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = node_to_expr(input, expr_arena); AggExpr::Min { - input: Box::new(exp), + input: Arc::new(exp), propagate_nans, } .into() @@ -572,7 +583,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { } => { let exp = 
node_to_expr(input, expr_arena); AggExpr::Max { - input: Box::new(exp), + input: Arc::new(exp), propagate_nans, } .into() @@ -580,27 +591,27 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AAggExpr::Median(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Median(Box::new(exp)).into() + AggExpr::Median(Arc::new(exp)).into() }, AAggExpr::NUnique(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::NUnique(Box::new(exp)).into() + AggExpr::NUnique(Arc::new(exp)).into() }, AAggExpr::First(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::First(Box::new(exp)).into() + AggExpr::First(Arc::new(exp)).into() }, AAggExpr::Last(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Last(Box::new(exp)).into() + AggExpr::Last(Arc::new(exp)).into() }, AAggExpr::Mean(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Mean(Box::new(exp)).into() + AggExpr::Mean(Arc::new(exp)).into() }, AAggExpr::Implode(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Implode(Box::new(exp)).into() + AggExpr::Implode(Arc::new(exp)).into() }, AAggExpr::Quantile { expr, @@ -610,31 +621,31 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let expr = node_to_expr(expr, expr_arena); let quantile = node_to_expr(quantile, expr_arena); AggExpr::Quantile { - expr: Box::new(expr), - quantile: Box::new(quantile), + expr: Arc::new(expr), + quantile: Arc::new(quantile), interpol, } .into() }, AAggExpr::Sum(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Sum(Box::new(exp)).into() + AggExpr::Sum(Arc::new(exp)).into() }, AAggExpr::Std(expr, ddof) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Std(Box::new(exp), ddof).into() + AggExpr::Std(Arc::new(exp), ddof).into() }, AAggExpr::Var(expr, ddof) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::Var(Box::new(exp), ddof).into() + AggExpr::Var(Arc::new(exp), ddof).into() }, AAggExpr::AggGroups(expr) => { let exp = node_to_expr(expr, expr_arena); - AggExpr::AggGroups(Box::new(exp)).into() + AggExpr::AggGroups(Arc::new(exp)).into() }, AAggExpr::Count(expr, include_nulls) => { let expr = node_to_expr(expr, expr_arena); - AggExpr::Count(Box::new(expr), include_nulls).into() + AggExpr::Count(Arc::new(expr), include_nulls).into() }, }, AExpr::Ternary { @@ -647,9 +658,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let f = node_to_expr(falsy, expr_arena); Expr::Ternary { - predicate: Box::new(p), - truthy: Box::new(t), - falsy: Box::new(f), + predicate: Arc::new(p), + truthy: Arc::new(t), + falsy: Arc::new(f), } }, AExpr::AnonymousFunction { @@ -677,7 +688,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { partition_by, options, } => { - let function = Box::new(node_to_expr(function, expr_arena)); + let function = Arc::new(node_to_expr(function, expr_arena)); let partition_by = nodes_to_exprs(&partition_by, expr_arena); Expr::Window { function, @@ -690,9 +701,9 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { offset, length, } => Expr::Slice { - input: Box::new(node_to_expr(input, expr_arena)), - offset: Box::new(node_to_expr(offset, expr_arena)), - length: Box::new(node_to_expr(length, expr_arena)), + input: Arc::new(node_to_expr(input, expr_arena)), + offset: Arc::new(node_to_expr(offset, expr_arena)), + length: Arc::new(node_to_expr(length, expr_arena)), }, AExpr::Len => Expr::Len, AExpr::Nth(i) => Expr::Nth(i), diff --git a/crates/polars-plan/src/logical_plan/iterator.rs 
b/crates/polars-plan/src/logical_plan/iterator.rs index 611f08badd83..9b1ac7ecb01e 100644 --- a/crates/polars-plan/src/logical_plan/iterator.rs +++ b/crates/polars-plan/src/logical_plan/iterator.rs @@ -1,55 +1,58 @@ -use arrow::legacy::error::PolarsResult; +use std::sync::Arc; + +use polars_core::error::PolarsResult; use polars_utils::idx_vec::UnitVec; use polars_utils::unitvec; +use visitor::{RewritingVisitor, TreeWalker}; use crate::prelude::*; macro_rules! push_expr { - ($current_expr:expr, $push:ident, $iter:ident) => {{ + ($current_expr:expr, $c:ident, $push:ident, $push_owned:ident, $iter:ident) => {{ use Expr::*; match $current_expr { Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) | Len => {}, - Alias(e, _) => $push(e), + Alias(e, _) => $push($c, e), BinaryExpr { left, op: _, right } => { // reverse order so that left is popped first - $push(right); - $push(left); + $push($c, right); + $push($c, left); }, - Cast { expr, .. } => $push(expr), - Sort { expr, .. } => $push(expr), + Cast { expr, .. } => $push($c, expr), + Sort { expr, .. } => $push($c, expr), Gather { expr, idx, .. } => { - $push(idx); - $push(expr); + $push($c, idx); + $push($c, expr); }, Filter { input, by } => { - $push(by); + $push($c, by); // latest, so that it is popped first - $push(input); + $push($c, input); }, SortBy { expr, by, .. } => { for e in by { - $push(e) + $push_owned($c, e) } // latest, so that it is popped first - $push(expr); + $push($c, expr); }, Agg(agg_e) => { use AggExpr::*; match agg_e { - Max { input, .. } => $push(input), - Min { input, .. } => $push(input), - Mean(e) => $push(e), - Median(e) => $push(e), - NUnique(e) => $push(e), - First(e) => $push(e), - Last(e) => $push(e), - Implode(e) => $push(e), - Count(e, _) => $push(e), - Quantile { expr, .. } => $push(expr), - Sum(e) => $push(e), - AggGroups(e) => $push(e), - Std(e, _) => $push(e), - Var(e, _) => $push(e), + Max { input, .. } => $push($c, input), + Min { input, .. } => $push($c, input), + Mean(e) => $push($c, e), + Median(e) => $push($c, e), + NUnique(e) => $push($c, e), + First(e) => $push($c, e), + Last(e) => $push($c, e), + Implode(e) => $push($c, e), + Count(e, _) => $push($c, e), + Quantile { expr, .. } => $push($c, expr), + Sum(e) => $push($c, e), + AggGroups(e) => $push($c, e), + Std(e, _) => $push($c, e), + Var(e, _) => $push($c, e), } }, Ternary { @@ -57,40 +60,40 @@ macro_rules! push_expr { falsy, predicate, } => { - $push(predicate); - $push(falsy); + $push($c, predicate); + $push($c, falsy); // latest, so that it is popped first - $push(truthy); + $push($c, truthy); }, // we iterate in reverse order, so that the lhs is popped first and will be found // as the root columns/ input columns by `_suffix` and `_keep_name` etc. - AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push(e)), - Function { input, .. } => input.$iter().rev().for_each(|e| $push(e)), - Explode(e) => $push(e), + AnonymousFunction { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Function { input, .. } => input.$iter().rev().for_each(|e| $push_owned($c, e)), + Explode(e) => $push($c, e), Window { function, partition_by, .. 
} => { for e in partition_by.into_iter().rev() { - $push(e) + $push_owned($c, e) } // latest so that it is popped first - $push(function); + $push($c, function); }, Slice { input, offset, length, } => { - $push(length); - $push(offset); + $push($c, length); + $push($c, offset); // latest, so that it is popped first - $push(input); + $push($c, input); }, - Exclude(e, _) => $push(e), - KeepName(e) => $push(e), - RenameAlias { expr, .. } => $push(expr), + Exclude(e, _) => $push($c, e), + KeepName(e) => $push($c, e), + RenameAlias { expr, .. } => $push($c, expr), SubPlan { .. } => {}, // pass Selector(_) => {}, @@ -98,47 +101,6 @@ macro_rules! push_expr { }}; } -impl Expr { - /// Expr::mutate().apply(fn()) - pub fn mutate(&mut self) -> ExprMut { - let stack = unitvec!(self); - ExprMut { stack } - } -} - -pub struct ExprMut<'a> { - stack: UnitVec<&'a mut Expr>, -} - -impl<'a> ExprMut<'a> { - /// - /// # Arguments - /// * `f` - A function that may mutate an expression. If the function returns `true` iteration - /// continues. - pub fn apply(&mut self, mut f: F) - where - F: FnMut(&mut Expr) -> bool, - { - let _ = self.try_apply(|e| Ok(f(e))); - } - - pub fn try_apply(&mut self, mut f: F) -> PolarsResult<()> - where - F: FnMut(&mut Expr) -> PolarsResult, - { - while let Some(current_expr) = self.stack.pop() { - // the order is important, we first modify the Expr - // before we push its children on the stack. - // The modification can make the children invalid. - if !f(current_expr)? { - break; - } - current_expr.nodes_mut(&mut self.stack) - } - Ok(()) - } -} - pub struct ExprIter<'a> { stack: UnitVec<&'a Expr>, } @@ -154,15 +116,36 @@ impl<'a> Iterator for ExprIter<'a> { } } +pub struct ExprMapper { + f: F, +} + +impl PolarsResult> RewritingVisitor for ExprMapper { + type Node = Expr; + + fn mutate(&mut self, node: Self::Node) -> PolarsResult { + (self.f)(node) + } +} + impl Expr { pub fn nodes<'a>(&'a self, container: &mut UnitVec<&'a Expr>) { - let mut push = |e: &'a Expr| container.push(e); - push_expr!(self, push, iter); + let push = |c: &mut UnitVec<&'a Expr>, e: &'a Expr| c.push(e); + push_expr!(self, container, push, push, iter); + } + + pub fn nodes_owned(self, container: &mut UnitVec) { + let push_arc = |c: &mut UnitVec, e: Arc| c.push(Arc::unwrap_or_clone(e)); + let push_owned = |c: &mut UnitVec, e: Expr| c.push(e); + push_expr!(self, container, push_arc, push_owned, into_iter); + } + + pub fn map_expr Self>(self, mut f: F) -> Self { + self.rewrite(&mut ExprMapper { f: |e| Ok(f(e)) }).unwrap() } - pub fn nodes_mut<'a>(&'a mut self, container: &mut UnitVec<&'a mut Expr>) { - let mut push = |e: &'a mut Expr| container.push(e); - push_expr!(self, push, iter_mut); + pub fn try_map_expr PolarsResult>(self, f: F) -> PolarsResult { + self.rewrite(&mut ExprMapper { f }) } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 679b58740b1c..c3bf2d2a2ea3 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -123,13 +123,15 @@ impl<'a> PredicatePushDown<'a> { if needs_rename { // TODO! Do this directly on AExpr. 
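// The rename below uses `map_expr`, which drives the `ExprMapper` visitor
// from iterator.rs through the tree-walking `rewrite`. An equivalent
// fallible variant via `try_map_expr`, sketched under the same API (the
// function name is illustrative):
fn try_rename_column(expr: Expr, from: &str, to: &str) -> PolarsResult<Expr> {
    expr.try_map_expr(|e| match e {
        Expr::Column(name) if &*name == from => Ok(Expr::Column(Arc::from(to))),
        e => Ok(e),
    })
}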
let mut new_expr = node_to_expr(e.node(), expr_arena); - new_expr.mutate().apply(|e| { - if let Expr::Column(name) = e { - if let Some(rename_to) = alias_rename_map.get(name) { - *name = rename_to.clone(); - }; - }; - true + new_expr = new_expr.map_expr(|e| match e { + Expr::Column(name) => { + if let Some(rename_to) = alias_rename_map.get(&*name) { + Expr::Column(rename_to.clone()) + } else { + Expr::Column(name) + } + }, + e => e, }); let predicate = to_aexpr(new_expr, expr_arena); e.set_node(predicate); diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 343e0d0498e3..6450c4822bb5 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -5,33 +5,20 @@ use super::*; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. -pub(super) fn replace_wildcard_with_column(mut expr: Expr, column_name: Arc) -> Expr { - expr.mutate().apply(|e| { - match e { - Expr::Wildcard => { - *e = Expr::Column(column_name.clone()); - }, - Expr::Exclude(input, _) => { - *e = replace_wildcard_with_column(std::mem::take(input), column_name.clone()); - }, - _ => {}, - } - // always keep iterating all inputs - true - }); - expr +pub(super) fn replace_wildcard_with_column(expr: Expr, column_name: Arc) -> Expr { + expr.map_expr(|e| match e { + Expr::Wildcard => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } #[cfg(feature = "regex")] -fn remove_exclude(mut expr: Expr) -> Expr { - expr.mutate().apply(|e| { - if let Expr::Exclude(input, _) = e { - *e = remove_exclude(std::mem::take(input)); - } - // always keep iterating all inputs - true - }); - expr +fn remove_exclude(expr: Expr) -> Expr { + expr.map_expr(|e| match e { + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } fn rewrite_special_aliases(expr: Expr) -> PolarsResult { @@ -79,22 +66,22 @@ fn replace_wildcard( Ok(()) } -fn replace_nth(expr: &mut Expr, schema: &Schema) { - expr.mutate().apply(|e| match e { - Expr::Nth(i) => { +fn replace_nth(expr: Expr, schema: &Schema) -> Expr { + expr.map_expr(|e| { + if let Expr::Nth(i) = e { match i.negative_to_usize(schema.len()) { None => { - let name = if *i == 0 { "first" } else { "last" }; - *e = Expr::Column(ColumnName::from(name)); + let name = if i == 0 { "first" } else { "last" }; + Expr::Column(ColumnName::from(name)) }, Some(idx) => { let (name, _dtype) = schema.get_at_index(idx).unwrap(); - *e = Expr::Column(ColumnName::from(&**name)) + Expr::Column(ColumnName::from(&**name)) }, } - true - }, - _ => true, + } else { + e + } }) } @@ -114,12 +101,11 @@ fn expand_regex( if re.is_match(name) && !exclude.contains(name.as_str()) { let mut new_expr = remove_exclude(expr.clone()); - new_expr.mutate().apply(|e| match &e { + new_expr = new_expr.map_expr(|e| match e { Expr::Column(pat) if pat.as_ref() == pattern => { - *e = Expr::Column(ColumnName::from(name.as_str())); - true + Expr::Column(ColumnName::from(name.as_str())) }, - _ => true, + e => e, }); let new_expr = rewrite_special_aliases(new_expr)?; @@ -203,21 +189,12 @@ fn expand_columns( /// This replaces the dtypes Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. 
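// How a dtype selector fans out: the surrounding expression is cloned once
// per matching column, with the selector replaced by that column. A hedged
// sketch of the driving loop (illustrative; the real logic lives in
// `rewrite_projections` below, and the schema iteration API is assumed):
fn expand_dtype_matches(expr: &Expr, schema: &Schema, dtype: &DataType, out: &mut Vec<Expr>) {
    for (name, dt) in schema.iter() {
        if dt == dtype {
            out.push(replace_dtype_with_column(
                expr.clone(),
                ColumnName::from(name.as_str()),
            ));
        }
    }
}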
-pub(super) fn replace_dtype_with_column(mut expr: Expr, column_name: Arc) -> Expr { - expr.mutate().apply(|e| { - match e { - Expr::DtypeColumn(_) => { - *e = Expr::Column(column_name.clone()); - }, - Expr::Exclude(input, _) => { - *e = replace_dtype_with_column(std::mem::take(input), column_name.clone()); - }, - _ => {}, - } - // always keep iterating all inputs - true - }); - expr +pub(super) fn replace_dtype_with_column(expr: Expr, column_name: Arc) -> Expr { + expr.map_expr(|e| match e { + Expr::DtypeColumn(_) => Expr::Column(column_name.clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }) } /// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the @@ -228,26 +205,18 @@ pub(super) fn replace_columns_with_column( column_name: &str, ) -> (Expr, bool) { let mut is_valid = true; - expr.mutate().apply(|e| { - match e { - Expr::Columns(members) => { - // `col([a, b]) + col([c, d])` - if members == names { - *e = Expr::Column(ColumnName::from(column_name)); - } else { - is_valid = false; - } - }, - Expr::Exclude(input, _) => { - let (new_expr, new_expr_valid) = - replace_columns_with_column(std::mem::take(input), names, column_name); - *e = new_expr; - is_valid &= new_expr_valid; - }, - _ => {}, - } - // always keep iterating all inputs - true + expr = expr.map_expr(|e| match e { + Expr::Columns(members) => { + // `col([a, b]) + col([c, d])` + if members == names { + Expr::Column(ColumnName::from(column_name)) + } else { + is_valid = false; + Expr::Columns(members) + } + }, + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, }); (expr, is_valid) } @@ -363,18 +332,16 @@ fn prepare_excluded( } // functions can have col(["a", "b"]) or col(String) as inputs -fn expand_function_inputs(mut expr: Expr, schema: &Schema) -> Expr { - expr.mutate().apply(|e| match e { +fn expand_function_inputs(expr: Expr, schema: &Schema) -> Expr { + expr.map_expr(|mut e| match &mut e { Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } if options.input_wildcard_expansion => { - *input = rewrite_projections(input.clone(), schema, &[]).unwrap(); - // continue iteration, there might be more functions. - true + *input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap(); + e }, - _ => true, - }); - expr + _ => e, + }) } /// this is determined in type coercion @@ -460,7 +427,7 @@ pub(crate) fn rewrite_projections( let mut flags = find_flags(&expr); if flags.has_selector { - replace_selector(&mut expr, schema, keys)?; + expr = replace_selector(expr, schema, keys)?; // the selector is replaced with Expr::Columns flags.multiple_columns = true; } @@ -475,21 +442,19 @@ pub(crate) fn rewrite_projections( // them up there. if flags.replace_fill_null_type { for e in &mut result[result_offset..] { - e.mutate().apply(|e| { + *e = e.clone().map_expr(|mut e| { if let Expr::Function { input, function: FunctionExpr::FillNull { super_type }, .. 
- } = e + } = &mut e { if let Some(new_st) = early_supertype(input, schema) { *super_type = new_st; } } - - // continue iteration - true - }) + e + }); } } } @@ -504,7 +469,7 @@ fn replace_and_add_to_results( keys: &[Expr], ) -> PolarsResult<()> { if flags.has_nth { - replace_nth(&mut expr, schema); + expr = replace_nth(expr, schema); } // has multiple column names @@ -603,20 +568,18 @@ fn replace_selector_inner( Ok(()) } -fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<()> { - // first pass we replace the selectors - // with Expr::Columns - // we expand the `to_add` columns - // and then subtract the `to_subtract` columns - expr.mutate().try_apply(|e| match e { - Expr::Selector(s) => { +fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult { + // First pass we replace the selectors with Expr::Columns, we expand the `to_add` columns + // and then subtract the `to_subtract` columns. + expr.try_map_expr(|e| match e { + Expr::Selector(mut s) => { let mut swapped = Selector::Root(Box::new(Expr::Wildcard)); - std::mem::swap(s, &mut swapped); + std::mem::swap(&mut s, &mut swapped); let mut members = PlIndexSet::new(); replace_selector_inner(swapped, &mut members, &mut vec![], schema, keys)?; - *e = Expr::Columns( + Ok(Expr::Columns( members .into_iter() .map(|e| { @@ -626,11 +589,8 @@ fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsRe name.to_string() }) .collect(), - ); - - Ok(true) + )) }, - _ => Ok(true), - })?; - Ok(()) + e => Ok(e), + }) } diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 8faf4e67fcce..e75dee347dd3 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -24,8 +24,62 @@ impl TreeWalker for Expr { Ok(VisitRecursion::Continue) } - fn map_children(self, _op: &mut dyn FnMut(Self) -> PolarsResult) -> PolarsResult { - todo!() + fn map_children(self, mut f: &mut dyn FnMut(Self) -> PolarsResult) -> PolarsResult { + use polars_utils::functions::try_arc_map as am; + use AggExpr::*; + use Expr::*; + #[rustfmt::skip] + let ret = match self { + Alias(l, r) => Alias(am(l, f)?, r), + Column(_) => self, + Columns(_) => self, + DtypeColumn(_) => self, + Literal(_) => self, + BinaryExpr { left, op, right } => { + BinaryExpr { left: am(left, &mut f)? 
, op, right: am(right, f)?} + }, + Cast { expr, data_type, strict } => Cast { expr: am(expr, f)?, data_type, strict }, + Sort { expr, options } => Sort { expr: am(expr, f)?, options }, + Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, + SortBy { expr, by, descending } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, descending }, + Agg(agg_expr) => Agg(match agg_expr { + Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans }, + Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans }, + Median(x) => Median(am(x, f)?), + NUnique(x) => NUnique(am(x, f)?), + First(x) => First(am(x, f)?), + Last(x) => Last(am(x, f)?), + Mean(x) => Mean(am(x, f)?), + Implode(x) => Implode(am(x, f)?), + Count(x, nulls) => Count(am(x, f)?, nulls), + Quantile { expr, quantile, interpol } => Quantile { expr: am(expr, &mut f)?, quantile: am(quantile, f)?, interpol }, + Sum(x) => Sum(am(x, f)?), + AggGroups(x) => AggGroups(am(x, f)?), + Std(x, ddf) => Std(am(x, f)?, ddf), + Var(x, ddf) => Var(am(x, f)?, ddf), + }), + Ternary { predicate, truthy, falsy } => Ternary { predicate: am(predicate, &mut f)?, truthy: am(truthy, &mut f)?, falsy: am(falsy, f)? }, + Function { input, function, options } => Function { input: input.into_iter().map(f).collect::>()?, function, options }, + Explode(expr) => Explode(am(expr, f)?), + Filter { input, by } => Filter { input: am(input, &mut f)?, by: am(by, f)? }, + Window { function, partition_by, options } => { + let partition_by = partition_by.into_iter().map(&mut f).collect::>()?; + Window { function: am(function, f)?, partition_by, options } + }, + Wildcard => Wildcard, + Slice { input, offset, length } => Slice { input: am(input, &mut f)?, offset: am(offset, &mut f)?, length: am(length, f)? }, + Exclude(expr, excluded) => Exclude(am(expr, f)?, excluded), + KeepName(expr) => KeepName(am(expr, f)?), + Len => Len, + Nth(_) => self, + RenameAlias { function, expr } => RenameAlias { function, expr: am(expr, f)? }, + AnonymousFunction { input, function, output_type, options } => { + AnonymousFunction { input: input.into_iter().map(f).collect::>()?, function, output_type, options } + }, + SubPlan(_, _) => self, + Selector(_) => self, + }; + Ok(ret) } } diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index ff578ea599f7..258d42b4fc03 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -277,12 +277,9 @@ pub(crate) fn rename_matching_aexpr_leaf_names( if leaves.any(|node| matches!(arena.get(node.0), AExpr::Column(name) if &**name == current)) { // we convert to expression as we cannot easily copy the aexpr. 
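// The pattern used here: materialize the arena node back into an owned
// `Expr`, rewrite it functionally, then insert the result into the arena
// again. A hedged outline of that round-trip with this crate's helpers:
fn rewrite_via_expr(node: Node, arena: &mut Arena<AExpr>, f: impl Fn(Expr) -> Expr) -> Node {
    // The immutable borrow ends before `to_aexpr` takes the arena mutably.
    let expr = node_to_expr(node, arena);
    to_aexpr(f(expr), arena)
}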
let mut new_expr = node_to_expr(node, arena); - new_expr.mutate().apply(|e| match e { - Expr::Column(name) if &**name == current => { - *name = ColumnName::from(new_name); - true - }, - _ => true, + new_expr = new_expr.map_expr(|e| match e { + Expr::Column(name) if &*name == current => Expr::Column(ColumnName::from(new_name)), + e => e, }); to_aexpr(new_expr, arena) } else { diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 4e84219c9963..9a9963ed3259 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -525,15 +525,16 @@ impl SQLContext { fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame { let mut contexts = vec![]; for expr in exprs { - expr.mutate().apply(|e| { - if let Expr::SubPlan(lp, names) = e { - contexts.push(::from((***lp).clone())); - + *expr = expr.clone().map_expr(|e| match e { + Expr::SubPlan(lp, names) => { + contexts.push(::from((**lp).clone())); if names.len() == 1 { - *e = Expr::Column(names[0].as_str().into()); + Expr::Column(names[0].as_str().into()) + } else { + Expr::SubPlan(lp, names) } - }; - true + }, + e => e, }) } diff --git a/crates/polars-utils/src/functions.rs b/crates/polars-utils/src/functions.rs index 4ff1d724cefb..528bae5ed291 100644 --- a/crates/polars-utils/src/functions.rs +++ b/crates/polars-utils/src/functions.rs @@ -1,4 +1,6 @@ +use std::mem::MaybeUninit; use std::ops::Range; +use std::sync::Arc; // The ith portion of a range split in k (as equal as possible) parts. #[inline(always)] @@ -23,3 +25,40 @@ pub fn flatten>(bufs: &[R], len: Option) -> Vec T>(mut arc: Arc, mut f: F) -> Arc { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())); + + // Now the Arc is properly initialized again. + Arc::from_raw(Arc::into_raw(uninit_arc).cast::()) + } +} + +pub fn try_arc_map Result>( + mut arc: Arc, + mut f: F, +) -> Result, E> { + unsafe { + // Make the Arc unique (cloning if necessary). + Arc::make_mut(&mut arc); + + // If f panics we must be able to drop the Arc without assuming it is initialized. + let mut uninit_arc = Arc::from_raw(Arc::into_raw(arc).cast::>()); + + // Replace the value inside the arc. + let ptr = Arc::get_mut(&mut uninit_arc).unwrap_unchecked() as *mut MaybeUninit; + *ptr = MaybeUninit::new(f(ptr.read().assume_init())?); + + // Now the Arc is properly initialized again. 
+ Ok(Arc::from_raw(Arc::into_raw(uninit_arc).cast::())) + } +} From 345ca75baf2df276de017e2270515252b9e54a34 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 30 Mar 2024 16:42:01 +0100 Subject: [PATCH 17/30] feat: CSV-writer escape carriage return (#15399) --- crates/polars-io/src/csv/write_impl.rs | 28 +++++++++++++++++++------- py-polars/tests/unit/io/test_csv.py | 8 ++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/crates/polars-io/src/csv/write_impl.rs b/crates/polars-io/src/csv/write_impl.rs index db79acb80650..ff84753eea80 100644 --- a/crates/polars-io/src/csv/write_impl.rs +++ b/crates/polars-io/src/csv/write_impl.rs @@ -9,7 +9,7 @@ use arrow::legacy::time_zone::Tz; use arrow::temporal_conversions; #[cfg(feature = "timezones")] use chrono::TimeZone; -use memchr::{memchr, memchr2}; +use memchr::{memchr3, memmem}; use polars_core::prelude::*; use polars_core::series::SeriesIter; use polars_core::POOL; @@ -20,7 +20,15 @@ use serde::{Deserialize, Serialize}; use super::write::QuoteStyle; -fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> std::io::Result<()> { +const LF: u8 = b'\n'; +const CR: u8 = b'\r'; + +fn fmt_and_escape_str( + f: &mut Vec, + v: &str, + options: &SerializeOptions, + find_quotes: &memmem::Finder, +) -> std::io::Result<()> { if options.quote_style == QuoteStyle::Never { return write!(f, "{v}"); } @@ -28,7 +36,7 @@ fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> s if v.is_empty() { return write!(f, "{quote}{quote}"); } - let needs_escaping = memchr(options.quote_char, v.as_bytes()).is_some(); + let needs_escaping = find_quotes.find(v.as_bytes()).is_some(); if needs_escaping { let replaced = unsafe { // Replace from single quote " to double quote "". @@ -41,7 +49,7 @@ fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> s } let surround_with_quotes = match options.quote_style { QuoteStyle::Always | QuoteStyle::NonNumeric => true, - QuoteStyle::Necessary => memchr2(options.separator, b'\n', v.as_bytes()).is_some(), + QuoteStyle::Necessary => memchr3(options.separator, LF, CR, v.as_bytes()).is_some(), QuoteStyle::Never => false, }; @@ -72,17 +80,18 @@ unsafe fn write_any_value( datetime_formats: &[&str], time_zones: &[Option], i: usize, + find_quotes: &memmem::Finder, ) -> PolarsResult<()> { match value { // First do the string-like types as they know how to deal with quoting. 
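        // (The `find_quotes` Finder passed in here is built once per write or
        // write_header call and reused for every field, instead of re-scanning
        // for the quote char from scratch each time.)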
AnyValue::String(v) => { - fmt_and_escape_str(f, v, options)?; + fmt_and_escape_str(f, v, options, find_quotes)?; Ok(()) }, #[cfg(feature = "dtype-categorical")] AnyValue::Categorical(idx, rev_map, _) | AnyValue::Enum(idx, rev_map, _) => { let v = rev_map.get(idx); - fmt_and_escape_str(f, v, options)?; + fmt_and_escape_str(f, v, options, find_quotes)?; Ok(()) }, _ => { @@ -410,6 +419,8 @@ pub(crate) fn write( let last_ptr = &col_iters[col_iters.len() - 1] as *const SeriesIter; let mut finished = false; + let binding = &[options.quote_char]; + let find_quotes = memmem::Finder::new(binding); // loop rows while !finished { for (i, col) in &mut col_iters.iter_mut().enumerate() { @@ -422,6 +433,7 @@ pub(crate) fn write( &datetime_formats, &time_zones, i, + &find_quotes, )?; }, None => { @@ -475,8 +487,10 @@ pub(crate) fn write_header( let mut escaped_names: Vec = Vec::with_capacity(names.len()); let mut nm: Vec = vec![]; + let binding = &[options.quote_char]; + let find_quotes = memmem::Finder::new(binding); for name in names { - fmt_and_escape_str(&mut nm, name, options)?; + fmt_and_escape_str(&mut nm, name, options, &find_quotes)?; unsafe { // SAFETY: we know headers will be valid UTF-8 at this point escaped_names.push(std::str::from_utf8_unchecked(&nm).to_string()); diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 127384441357..4645a0feef92 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1969,3 +1969,11 @@ def test_read_csv_single_column(columns: list[str] | str) -> None: def test_csv_invalid_escape_utf8_14960() -> None: with pytest.raises(pl.ComputeError, match=r"field is not properly escaped"): pl.read_csv('col1\n""•'.encode()) + + +def test_csv_escape_cf_15349() -> None: + f = io.BytesIO() + df = pl.DataFrame({"test": ["normal", "with\rcr"]}) + df.write_csv(f) + f.seek(0) + assert f.read() == b'test\nnormal\n"with\rcr"\n' From ab5c0ee4f963cf82ceefac11bffe9c8bbcffe469 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sun, 31 Mar 2024 15:59:33 +0800 Subject: [PATCH 18/30] fix: `to_any_value` should supports all LiteralValue type (#15387) --- crates/polars-plan/src/logical_plan/lit.rs | 38 ++++++++++++++++++- py-polars/tests/unit/functions/test_repeat.py | 2 + 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index efea5e55b4b8..ac206ba7e9fa 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -116,7 +116,43 @@ impl LiteralValue { DateTime(v, tu, tz) => AnyValue::Datetime(*v, *tu, tz), #[cfg(feature = "dtype-time")] Time(v) => AnyValue::Time(*v), - _ => return None, + Series(s) => AnyValue::List(s.0.clone().into_series()), + Range { + low, + high, + data_type, + } => { + let opt_s = match data_type { + DataType::Int32 => { + if *low < i32::MIN as i64 || *high > i32::MAX as i64 { + return None; + } + + let low = *low as i32; + let high = *high as i32; + new_int_range::(low, high, 1, "range").ok() + }, + DataType::Int64 => { + let low = *low; + let high = *high; + new_int_range::(low, high, 1, "range").ok() + }, + DataType::UInt32 => { + if *low < 0 || *high > u32::MAX as i64 { + return None; + } + let low = *low as u32; + let high = *high as u32; + new_int_range::(low, high, 1, "range").ok() + }, + _ => return None, + }; + match opt_s { + Some(s) => AnyValue::List(s), + None => return None, + } + }, + Binary(v) => AnyValue::Binary(v), }; Some(av) } diff 
--git a/py-polars/tests/unit/functions/test_repeat.py b/py-polars/tests/unit/functions/test_repeat.py index 4b1d3138b592..b9c37aded947 100644 --- a/py-polars/tests/unit/functions/test_repeat.py +++ b/py-polars/tests/unit/functions/test_repeat.py @@ -28,6 +28,8 @@ (8, 2, pl.UInt8, pl.UInt8), (date(2023, 2, 2), 3, pl.Datetime, pl.Datetime), (7.5, 5, pl.UInt16, pl.UInt16), + ([1, 2, 3], 2, pl.List(pl.Int64), pl.List(pl.Int64)), + (b"ab12", 3, pl.Binary, pl.Binary), ], ) def test_repeat( From 21793740ce5d16baca452d562742a7c5a57cde56 Mon Sep 17 00:00:00 2001 From: Sol-Hee <72716774+Sol-Hee@users.noreply.github.com> Date: Sun, 31 Mar 2024 17:05:34 +0900 Subject: [PATCH 19/30] docs: Add `outer_coalesce` join strategy in the user guide (#15405) --- .../user-guide/transformations/joins.py | 7 +++++ .../rust/user-guide/transformations/joins.rs | 16 +++++++++- docs/user-guide/transformations/joins.md | 29 ++++++++++++++----- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/docs/src/python/user-guide/transformations/joins.py b/docs/src/python/user-guide/transformations/joins.py index 98828020820d..663d68b49517 100644 --- a/docs/src/python/user-guide/transformations/joins.py +++ b/docs/src/python/user-guide/transformations/joins.py @@ -41,6 +41,13 @@ print(df_outer_join) # --8<-- [end:outer] +# --8<-- [start:outer_coalesce] +df_outer_coalesce_join = df_customers.join( + df_orders, on="customer_id", how="outer_coalesce" +) +print(df_outer_coalesce_join) +# --8<-- [end:outer_coalesce] + # --8<-- [start:df3] df_colors = pl.DataFrame( { diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs index 2fd924a6b1c4..80647431d0a1 100644 --- a/docs/src/rust/user-guide/transformations/joins.rs +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -58,12 +58,26 @@ fn main() -> Result<(), Box> { df_orders.clone().lazy(), [col("customer_id")], [col("customer_id")], - JoinArgs::new(JoinType::Outer { coalesce: true }), + JoinArgs::new(JoinType::Outer { coalesce: false }), ) .collect()?; println!("{}", &df_outer_join); // --8<-- [end:outer] + // --8<-- [start:outer_coalesce] + let df_outer_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Outer { coalesce: true }), + ) + .collect()?; + println!("{}", &df_outer_join); + // --8<-- [end:outer_coalesce] + // --8<-- [start:df3] let df_colors = df!( "color"=> &["red", "blue", "green"], diff --git a/docs/user-guide/transformations/joins.md b/docs/user-guide/transformations/joins.md index 1a6f29337191..70efcce5f310 100644 --- a/docs/user-guide/transformations/joins.md +++ b/docs/user-guide/transformations/joins.md @@ -4,14 +4,15 @@ Polars supports the following join strategies by specifying the `how` argument: -| Strategy | Description | -| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `inner` | Returns row with matching keys in _both_ frames. Non-matching rows in either the left or right frame are discarded. | -| `left` | Returns all rows in the left dataframe, whether or not a match in the right-frame is found. Non-matching rows have their right columns null-filled. | -| `outer` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. 
-| `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicates rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. |
-| `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. |
-| `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. |
+| Strategy | Description |
+| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `inner` | Returns rows with matching keys in _both_ frames. Non-matching rows in either the left or right frame are discarded. |
+| `left` | Returns all rows in the left dataframe, whether or not a match in the right-frame is found. Non-matching rows have their right columns null-filled. |
+| `outer` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. |
+| `outer_coalesce` | Returns all rows from both the left and right dataframe. This is similar to `outer`, but with the key columns being merged. |
+| `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicate rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. |
+| `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. |
+| `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. |
 
 ### Inner join
 
@@ -62,6 +63,18 @@ The `outer` join produces a `DataFrame` that contains all the rows from both `Da
 --8<-- "python/user-guide/transformations/joins.py:outer"
 ```
 
+### Outer coalesce join
+
+The `outer_coalesce` join combines all rows from both `DataFrames` like an `outer` join, but it merges the join keys into a single column by coalescing the values. This ensures a unified view of the join key, avoiding nulls in key columns whenever possible. Let's compare it with the outer join using the two `DataFrames` we used above:
+
+{{code_block('user-guide/transformations/joins','outer_coalesce',['join'])}}
+
+```python exec="on" result="text" session="user-guide/transformations/joins"
+--8<-- "python/user-guide/transformations/joins.py:outer_coalesce"
+```
+
+In contrast to an `outer` join, where `customer_id` and `customer_id_right` columns would remain separate, the `outer_coalesce` join merges these columns into a single `customer_id` column.
+
 ### Cross join
 
 A `cross` join is a Cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`.
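To make the `outer` vs `outer_coalesce` contrast above concrete, here is a minimal self-contained sketch. The frame contents are hypothetical stand-ins, since the guide's `df_customers` and `df_orders` are defined elsewhere on the page:

```python
import polars as pl

df_customers = pl.DataFrame({"customer_id": [1, 2, 3], "name": ["foo", "ham", "spam"]})
df_orders = pl.DataFrame({"order_id": ["a", "b"], "customer_id": [1, 4]})

# `outer` keeps both key columns: `customer_id` from the left frame and
# `customer_id_right` from the right frame, null-filled where a side has no match.
print(df_customers.join(df_orders, on="customer_id", how="outer"))

# `outer_coalesce` merges the two key columns into a single `customer_id`.
print(df_customers.join(df_orders, on="customer_id", how="outer_coalesce"))
```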
From c477053c55099424974304e088c55d1187987c0c Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 31 Mar 2024 10:07:18 +0200 Subject: [PATCH 20/30] fix: Unset UpdateGroups after group-sensitive expression (#15400) --- .../src/physical_plan/expressions/apply.rs | 1 + py-polars/tests/unit/expr/test_expr_apply_eval.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 py-polars/tests/unit/expr/test_expr_apply_eval.py diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index 1bf465412894..7a447dbade11 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -100,6 +100,7 @@ impl ApplyExpr { ac.with_agg_state(AggState::AggregatedScalar( ca.explode().unwrap().into_series(), )); + ac.with_update_groups(UpdateGroups::No); } else { ac.with_series(ca.into_series(), true, Some(&self.expr))?; ac.with_update_groups(UpdateGroups::WithSeriesLen); diff --git a/py-polars/tests/unit/expr/test_expr_apply_eval.py b/py-polars/tests/unit/expr/test_expr_apply_eval.py new file mode 100644 index 000000000000..9e8fa0881947 --- /dev/null +++ b/py-polars/tests/unit/expr/test_expr_apply_eval.py @@ -0,0 +1,15 @@ +import polars as pl + + +def test_expression_15183() -> None: + assert ( + pl.DataFrame( + {"a": [1, 2, 3, 4, 5, 2, 3, 5, 1], "b": [1, 2, 3, 1, 2, 3, 1, 2, 3]} + ) + .group_by("a") + .agg(pl.col.b.unique().sort().str.concat("-").str.split("-")) + .sort("a") + ).to_dict(as_series=False) == { + "a": [1, 2, 3, 4, 5], + "b": [["1", "3"], ["2", "3"], ["1", "3"], ["1"], ["2"]], + } From 59ff9505eecff2796bafaf7a30bd98db47bc471b Mon Sep 17 00:00:00 2001 From: Marshall Date: Sun, 31 Mar 2024 04:08:19 -0400 Subject: [PATCH 21/30] fix: Return 0 for `n_unique()` in group-by context when group is empty (#15289) --- .../frame/group_by/aggregations/dispatch.rs | 26 ++++++------ .../src/frame/group_by/aggregations/mod.rs | 12 ++++++ py-polars/tests/unit/dataframe/test_df.py | 40 +++++++++++++++++++ .../tests/unit/operations/test_group_by.py | 2 +- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index 82f661dc0752..12ad27ec8fc6 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -81,23 +81,25 @@ impl Series { #[doc(hidden)] pub unsafe fn agg_n_unique(&self, groups: &GroupsProxy) -> Series { match groups { - GroupsProxy::Idx(groups) => agg_helper_idx_on_all::(groups, |idx| { - debug_assert!(idx.len() <= self.len()); - if idx.is_empty() { - None - } else { - let take = self.take_slice_unchecked(idx); - take.n_unique().ok().map(|v| v as IdxSize) - } - }), + GroupsProxy::Idx(groups) => { + agg_helper_idx_on_all_no_null::(groups, |idx| { + debug_assert!(idx.len() <= self.len()); + if idx.is_empty() { + 0 + } else { + let take = self.take_slice_unchecked(idx); + take.n_unique().unwrap() as IdxSize + } + }) + }, GroupsProxy::Slice { groups, .. 
} => { - _agg_helper_slice::(groups, |[first, len]| { + _agg_helper_slice_no_null::(groups, |[first, len]| { debug_assert!(len <= self.len() as IdxSize); if len == 0 { - None + 0 } else { let take = self.slice_from_offsets(first, len); - take.n_unique().ok().map(|v| v as IdxSize) + take.n_unique().unwrap() as IdxSize } }) }, diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index dc11bdbe402c..76b4339ae0eb 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -192,6 +192,18 @@ where ca.into_series() } +/// Same as `agg_helper_idx_on_all` but for aggregations that don't return an Option. +fn agg_helper_idx_on_all_no_null(groups: &GroupsIdx, f: F) -> Series +where + F: Fn(&IdxVec) -> T::Native + Send + Sync, + T: PolarsNumericType, + ChunkedArray: IntoSeries, +{ + let ca: NoNull> = + POOL.install(|| groups.all().into_par_iter().map(f).collect()); + ca.into_inner().into_series() +} + pub fn _agg_helper_slice(groups: &[[IdxSize; 2]], f: F) -> Series where F: Fn([IdxSize; 2]) -> Option + Send + Sync, diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 2ca9f40a88db..aaf95548a709 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1558,6 +1558,46 @@ def test_group_by_agg_n_unique_floats() -> None: assert out["b"].to_list() == [2, 1] +def test_group_by_agg_n_unique_empty_group_idx_path() -> None: + df = pl.DataFrame( + { + "key": [1, 1, 1, 2, 2, 2], + "value": [1, 2, 3, 4, 5, 6], + "filt": [True, True, True, False, False, False], + } + ) + out = df.group_by("key", maintain_order=True).agg( + pl.col("value").filter("filt").n_unique().alias("n_unique") + ) + expected = pl.DataFrame( + { + "key": [1, 2], + "n_unique": pl.Series([3, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(out, expected) + + +def test_group_by_agg_n_unique_empty_group_slice_path() -> None: + df = pl.DataFrame( + { + "key": [1, 1, 1, 2, 2, 2], + "value": [1, 2, 3, 4, 5, 6], + "filt": [False, False, False, False, False, False], + } + ) + out = df.group_by("key", maintain_order=True).agg( + pl.col("value").filter("filt").n_unique().alias("n_unique") + ) + expected = pl.DataFrame( + { + "key": [1, 2], + "n_unique": pl.Series([0, 0], dtype=pl.UInt32), + } + ) + assert_frame_equal(out, expected) + + def test_select_by_dtype(df: pl.DataFrame) -> None: out = df.select(pl.col(pl.String)) assert out.columns == ["strings", "strings_nulls"] diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 308598b61235..a53011144ea0 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -675,7 +675,7 @@ def test_group_by_multiple_column_reference() -> None: ("mean", [], [1.0, None], pl.Float64), ("median", [], [1.0, None], pl.Float64), ("min", [], [1, None], pl.Int64), - ("n_unique", [], [1, None], pl.UInt32), + ("n_unique", [], [1, 0], pl.UInt32), ("quantile", [0.5], [1.0, None], pl.Float64), ], ) From cd1994b63e32191640cca80da4fd420af0650378 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 31 Mar 2024 10:39:20 +0200 Subject: [PATCH 22/30] fix: Don't prune alias in function subtree (#15406) --- crates/polars-plan/src/logical_plan/conversion.rs | 8 +------- py-polars/tests/unit/test_schema.py | 5 +++++ 2 files changed, 6 insertions(+), 7 
deletions(-) diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 3c1f6f65ab3a..c78e8c5b56ea 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -221,13 +221,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta function, options, } => { - match function { - #[cfg(feature = "dtype-struct")] - FunctionExpr::AsStruct => { - state.prune_alias = false; - }, - _ => {}, - } + state.prune_alias = false; AExpr::Function { input: to_aexprs(input, arena, state), function, diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 1950486799a7..41d8c6777aef 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -638,3 +638,8 @@ def test_literal_subtract_schema_13284() -> None: def test_schema_boolean_sum_horizontal() -> None: lf = pl.LazyFrame({"a": [True, False]}).select(pl.sum_horizontal("a")) assert lf.schema == OrderedDict([("a", pl.UInt32)]) + + +def test_struct_alias_prune_15401() -> None: + df = pl.DataFrame({"a": []}, schema={"a": pl.Struct({"b": pl.Int8})}) + assert df.select(pl.col("a").alias("c").struct.field("b")).columns == ["b"] From 1485fec63a2550d06253cc18dc1d5e5e2f68bf94 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 31 Mar 2024 11:18:52 +0100 Subject: [PATCH 23/30] feat(python): make Series.__bool__ error message Rusttier (#15407) --- py-polars/polars/series/series.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 3bd8026651db..a651f8f77508 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -607,8 +607,12 @@ def shape(self) -> tuple[int]: def __bool__(self) -> NoReturn: msg = ( "the truth value of a Series is ambiguous" - "\n\nHint: use '&' or '|' to chain Series boolean results together, not and/or." - " To check if a Series contains any values, use `is_empty()`." + "\n\n" + "Here are some things you might want to try:\n" + "- instead of `if s`, use `if not s.is_empty()`\n" + "- instead of `s1 and s2`, use `s1 & s2`\n" + "- instead of `s1 or s2`, use `s1 | s2`\n" + "- instead of `s in [y, z]`, use `s.is_in([y, z])`\n" ) raise TypeError(msg) From 2b28777966a5667fa827b54d6a12d776bf2ef826 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 31 Mar 2024 12:48:39 +0200 Subject: [PATCH 24/30] fix: Ensure Binary -> Binview cast doesn't overflow the buffer size (#15408) --- .../polars-arrow/src/compute/cast/utf8_to.rs | 106 ++++++++++++++++-- 1 file changed, 99 insertions(+), 7 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index fadb6552beab..0f8892e498aa 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs +++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -5,6 +5,7 @@ use polars_utils::slice::GetSaferUnchecked; use polars_utils::vec::PushUnchecked; use crate::array::*; +use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; use crate::offset::Offset; use crate::types::NativeType; @@ -69,14 +70,51 @@ pub fn utf8_to_binary( } } +// Different types to test the overflow path. +#[cfg(not(test))] +type OffsetType = u32; + +// To trigger overflow +#[cfg(test)] +type OffsetType = i8; + +// If we don't do this the GC of binview will trigger. 
As we will split up buffers into multiple
+// chunks so that we don't overflow the offset u32.
+fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> {
+    // * 2, as it must be able to hold u32::MAX offset + u32::MAX len.
+    buf.clone()
+        .sliced(0, std::cmp::min(buf.len(), OffsetType::MAX as usize * 2))
+}
+
 pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
-    let buffer_idx = 0_u32;
-    let base_ptr = arr.values().as_ptr() as usize;
+    // Ensure we didn't accidentally set wrong type
+    #[cfg(not(debug_assertions))]
+    {
+        assert_eq!(
+            std::mem::size_of::<OffsetType>(),
+            std::mem::size_of::<u32>()
+        );
+    }
 
     let mut views = Vec::with_capacity(arr.len());
     let mut uses_buffer = false;
+
+    let mut base_buffer = arr.values().clone();
+    // Offset into the buffer
+    let mut base_ptr = base_buffer.as_ptr() as usize;
+
+    // Offset into the binview buffers
+    let mut buffer_idx = 0_u32;
+
+    // Binview buffers. Note that a buffer may look like it extends far beyond
+    // u32::MAX, but we never clone data: each chunk pushed below is a zero-copy
+    // slice of the same underlying buffer.
+    let mut buffers = vec![truncate_buffer(&base_buffer)];
+
     for bytes in arr.values_iter() {
-        let len: u32 = bytes.len().try_into().unwrap();
+        let len: u32 = bytes
+            .len()
+            .try_into()
+            .expect("max string/binary length exceeded");
 
         let mut payload = [0; 16];
         payload[0..4].copy_from_slice(&len.to_le_bytes());
@@ -85,18 +123,42 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray {
             payload[4..4 + bytes.len()].copy_from_slice(bytes);
         } else {
            uses_buffer = true;
+
+            // Copy the parts we know are correct.
             unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked_release(0..4)) };
-            let offset = (bytes.as_ptr() as usize - base_ptr) as u32;
             payload[0..4].copy_from_slice(&len.to_le_bytes());
-            payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes());
-            payload[12..16].copy_from_slice(&offset.to_le_bytes());
+
+            let current_bytes_ptr = bytes.as_ptr() as usize;
+            let offset = current_bytes_ptr - base_ptr;
+
+            // Here we check the overflow of the buffer offset.
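+            // If the offset still fits the view's 32-bit offset field we can
+            // keep referencing the current buffer; otherwise we re-slice a fresh
+            // base buffer below and restart the offset at zero.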
+ if let Ok(offset) = OffsetType::try_from(offset) { + #[allow(clippy::unnecessary_cast)] + let offset = offset as u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } else { + let len = base_buffer.len() - offset; + + // Set new buffer + base_buffer = base_buffer.clone().sliced(offset, len); + base_ptr = base_buffer.as_ptr() as usize; + + // And add the (truncated) one to the buffers + buffers.push(truncate_buffer(&base_buffer)); + buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded"); + + let offset = 0u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } } let value = View::from_le_bytes(payload); unsafe { views.push_unchecked(value) }; } let buffers = if uses_buffer { - Arc::from([arr.values().clone()]) + Arc::from(buffers) } else { Arc::from([]) }; @@ -114,3 +176,33 @@ pub fn binary_to_binview(arr: &BinaryArray) -> BinaryViewArray { pub fn utf8_to_utf8view(arr: &Utf8Array) -> Utf8ViewArray { unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn overflowing_utf8_to_binview() { + let values = [ + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "123", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "234", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "324", + ]; + let array = Utf8Array::::from_slice(values); + + let out = utf8_to_utf8view(&array); + // Ensure we hit the multiple buffers part. 
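+        // (With the test-only `OffsetType = i8`, the long values repeatedly
+        // overflow the tiny offset budget, forcing several buffer splits.)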
+ assert_eq!(out.buffers().len(), 6); + // Ensure we created a valid binview + let out = out.values_iter().collect::>(); + assert_eq!(out, values); + } +} From 5cdeea256665efb6305fd6c8b5023b59baba8f1b Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 1 Apr 2024 13:17:06 +0200 Subject: [PATCH 25/30] perf: Add non-order preserving variable row-encoding (#15414) --- .../src/chunked_array/logical/struct_/mod.rs | 5 +- .../ops/sort/arg_sort_multiple.rs | 51 ++++++++-- .../src/frame/group_by/into_groups.rs | 1 - crates/polars-core/src/frame/group_by/mod.rs | 6 +- .../executors/sinks/group_by/generic/eval.rs | 4 +- .../src/executors/sinks/sort/sink_multiple.rs | 15 ++- crates/polars-row/src/decode.rs | 6 +- crates/polars-row/src/encode.rs | 92 ++++++++++--------- crates/polars-row/src/fixed.rs | 12 +-- crates/polars-row/src/lib.rs | 2 +- crates/polars-row/src/row.rs | 24 ++++- crates/polars-row/src/variable.rs | 77 +++++++++++++--- 12 files changed, 203 insertions(+), 92 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs index 9457bfe00bd1..b5651cd2a405 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -9,10 +9,10 @@ use arrow::legacy::trusted_len::TrustedLenPush; use arrow::offset::OffsetsBuffer; use smartstring::alias::String as SmartString; -use self::sort::arg_sort_multiple::_get_rows_encoded_ca; use super::*; use crate::chunked_array::iterator::StructIter; use crate::datatypes::*; +use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered; use crate::utils::index_to_chunked_index; /// This is logical type [`StructChunked`] that @@ -415,8 +415,7 @@ impl StructChunked { } pub fn rows_encode(&self) -> PolarsResult { - let descending = vec![false; self.fields.len()]; - _get_rows_encoded_ca(self.name(), &self.fields, &descending, false) + _get_rows_encoded_ca_unordered(self.name(), &self.fields) } pub fn iter(&self) -> StructIter { diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 4deffaab0165..35e2d57decf3 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,5 +1,5 @@ use compare_inner::NullOrderCmp; -use polars_row::{convert_columns, RowsEncoded, SortField}; +use polars_row::{convert_columns, EncodingField, RowsEncoded}; use polars_utils::iter::EnumerateIdxTrait; use super::*; @@ -87,18 +87,19 @@ pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { Ok(out) } -pub(crate) fn encode_rows_vertical_par_default(by: &[Series]) -> PolarsResult { +pub(crate) fn encode_rows_vertical_par_unordered( + by: &[Series], +) -> PolarsResult { let n_threads = POOL.current_num_threads(); let len = by[0].len(); let splits = _split_offsets(len, n_threads); - let descending = vec![false; by.len()]; let chunks = splits.into_par_iter().map(|(offset, len)| { let sliced = by .iter() .map(|s| s.slice(offset as i64, len)) .collect::>(); - let rows = _get_rows_encoded(&sliced, &descending, false)?; + let rows = _get_rows_encoded_unordered(&sliced)?; Ok(rows.into_array()) }); let chunks = POOL.install(|| chunks.collect::>>()); @@ -106,12 +107,35 @@ pub(crate) fn encode_rows_vertical_par_default(by: &[Series]) -> PolarsResult PolarsResult { - let descending = vec![false; by.len()]; - let rows = 
_get_rows_encoded(by, &descending, false)?; +pub(crate) fn encode_rows_unordered(by: &[Series]) -> PolarsResult { + let rows = _get_rows_encoded_unordered(by)?; Ok(BinaryOffsetChunked::with_chunk("", rows.into_array())) } +pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { + let mut cols = Vec::with_capacity(by.len()); + let mut fields = Vec::with_capacity(by.len()); + for by in by { + let arr = _get_rows_encoded_compat_array(by)?; + let field = EncodingField::new_unsorted(); + match arr.data_type() { + // Flatten the struct fields. + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + for arr in arr.values() { + cols.push(arr.clone() as ArrayRef); + fields.push(field) + } + }, + _ => { + cols.push(arr); + fields.push(field) + }, + } + } + Ok(convert_columns(&cols, &fields)) +} + pub fn _get_rows_encoded( by: &[Series], descending: &[bool], @@ -123,9 +147,10 @@ pub fn _get_rows_encoded( for (by, descending) in by.iter().zip(descending) { let arr = _get_rows_encoded_compat_array(by)?; - let sort_field = SortField { + let sort_field = EncodingField { descending: *descending, nulls_last, + no_order: false, }; match arr.data_type() { // Flatten the struct fields. @@ -133,7 +158,7 @@ pub fn _get_rows_encoded( let arr = arr.as_any().downcast_ref::().unwrap(); for arr in arr.values() { cols.push(arr.clone() as ArrayRef); - fields.push(sort_field.clone()) + fields.push(sort_field) } }, _ => { @@ -155,6 +180,14 @@ pub fn _get_rows_encoded_ca( .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) } +pub fn _get_rows_encoded_ca_unordered( + name: &str, + by: &[Series], +) -> PolarsResult { + _get_rows_encoded_unordered(by) + .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) +} + pub(crate) fn argsort_multiple_row_fmt( by: &[Series], mut descending: Vec, diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index fe2fd5a493e5..f12779c52819 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -299,7 +299,6 @@ impl IntoGroupsProxy for BinaryChunked { }) .collect::>() }); - let byte_hashes = byte_hashes.iter().collect::>(); group_by_threaded_slice(byte_hashes, n_partitions, sorted) } else { let byte_hashes = fill_bytes_hashes(self, null_h, hb.clone()); diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index bc7e90406cf7..014ae2c8c28d 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -23,7 +23,7 @@ pub use into_groups::*; pub use proxy::*; use crate::prelude::sort::arg_sort_multiple::{ - encode_rows_default, encode_rows_vertical_par_default, + encode_rows_unordered, encode_rows_vertical_par_unordered, }; impl DataFrame { @@ -84,9 +84,9 @@ impl DataFrame { }) } else { let rows = if multithreaded { - encode_rows_vertical_par_default(&by) + encode_rows_vertical_par_unordered(&by) } else { - encode_rows_default(&by) + encode_rows_unordered(&by) }? 
.into_series(); rows.group_tuples(multithreaded, sorted) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs index 3fa0b384dd0c..c2b4262143da 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs @@ -1,7 +1,7 @@ use std::cell::UnsafeCell; use polars_core::export::ahash::RandomState; -use polars_row::{RowsEncoded, SortField}; +use polars_row::{EncodingField, RowsEncoded}; use super::*; use crate::executors::sinks::group_by::utils::prepare_key; @@ -18,7 +18,7 @@ pub(super) struct Eval { aggregation_series: UnsafeCell>, keys_columns: UnsafeCell>, hashes: Vec, - key_fields: Vec, + key_fields: Vec, // amortizes the encoding buffers rows_encoded: RowsEncoded, } diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 6fdede156fcf..d31d6e77e3a8 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -7,7 +7,7 @@ use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_plan::prelude::*; use polars_row::decode::decode_rows_from_binary; -use polars_row::SortField; +use polars_row::EncodingField; use super::*; use crate::operators::{ @@ -15,15 +15,12 @@ use crate::operators::{ }; const POLARS_SORT_COLUMN: &str = "__POLARS_SORT_COLUMN"; -fn get_sort_fields(sort_idx: &[usize], sort_args: &SortArguments) -> Vec { +fn get_sort_fields(sort_idx: &[usize], sort_args: &SortArguments) -> Vec { let mut descending = sort_args.descending.clone(); _broadcast_descending(sort_idx.len(), &mut descending); descending .into_iter() - .map(|descending| SortField { - descending, - nulls_last: sort_args.nulls_last, - }) + .map(|descending| EncodingField::new_sorted(descending, sort_args.nulls_last)) .collect() } @@ -61,7 +58,7 @@ fn finalize_dataframe( can_decode: bool, sort_dtypes: Option<&[ArrowDataType]>, rows: &mut Vec<&'static [u8]>, - sort_fields: &[SortField], + sort_fields: &[EncodingField], schema: &Schema, ) { unsafe { @@ -126,7 +123,7 @@ pub struct SortSinkMultiple { sort_sink: Box, sort_args: SortArguments, // Needed for encoding - sort_fields: Arc<[SortField]>, + sort_fields: Arc<[EncodingField]>, sort_dtypes: Option>, // amortize allocs sort_column: Vec, @@ -320,7 +317,7 @@ struct DropEncoded { can_decode: bool, sort_dtypes: Option>, rows: Vec<&'static [u8]>, - sort_fields: Arc<[SortField]>, + sort_fields: Arc<[EncodingField]>, output_schema: SchemaRef, } diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 246ac976fc10..180cf2ad00e8 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -10,7 +10,7 @@ use crate::variable::{decode_binary, decode_binview}; /// encodings. 
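 /// The `fields` must use the same per-column settings that were used when the rows were encoded.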
pub unsafe fn decode_rows_from_binary<'a>( arr: &'a BinaryArray, - fields: &[SortField], + fields: &[EncodingField], data_types: &[ArrowDataType], rows: &mut Vec<&'a [u8]>, ) -> Vec { @@ -27,7 +27,7 @@ pub unsafe fn decode_rows_from_binary<'a>( pub unsafe fn decode_rows( // the rows will be updated while the data is decoded rows: &mut [&[u8]], - fields: &[SortField], + fields: &[EncodingField], data_types: &[ArrowDataType], ) -> Vec { assert_eq!(fields.len(), data_types.len()); @@ -38,7 +38,7 @@ pub unsafe fn decode_rows( .collect() } -unsafe fn decode(rows: &mut [&[u8]], field: &SortField, data_type: &ArrowDataType) -> ArrayRef { +unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, data_type: &ArrowDataType) -> ArrayRef { match data_type { ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(), ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(), diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index bd94f33a203c..0cd660b47b9e 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -10,10 +10,10 @@ use polars_utils::slice::GetSaferUnchecked; use polars_utils::vec::PushUnchecked; use crate::fixed::FixedLengthEncoding; -use crate::row::{RowsEncoded, SortField}; +use crate::row::{EncodingField, RowsEncoded}; use crate::{with_match_arrow_primitive_type, ArrayRef}; -pub fn convert_columns(columns: &[ArrayRef], fields: &[SortField]) -> RowsEncoded { +pub fn convert_columns(columns: &[ArrayRef], fields: &[EncodingField]) -> RowsEncoded { let mut rows = RowsEncoded::new(vec![], vec![]); convert_columns_amortized(columns, fields, &mut rows); rows @@ -28,7 +28,7 @@ pub fn convert_columns_no_order(columns: &[ArrayRef]) -> RowsEncoded { pub fn convert_columns_amortized_no_order(columns: &[ArrayRef], rows: &mut RowsEncoded) { convert_columns_amortized( columns, - std::iter::repeat(&SortField::default()).take(columns.len()), + std::iter::repeat(&EncodingField::default()).take(columns.len()), rows, ); } @@ -41,7 +41,7 @@ enum Encoder { enc: Vec, rows: Option, original: LargeListArray, - field: SortField, + field: EncodingField, }, Leaf(ArrayRef), } @@ -112,7 +112,7 @@ impl Encoder { } } -fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &SortField) -> usize { +fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &EncodingField) -> usize { let mut added = 0; match arr.data_type() { ArrowDataType::Struct(_) => { @@ -134,7 +134,7 @@ fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &SortField) enc: inner, original: arr.clone(), rows: None, - field: field.clone(), + field: *field, }); added += 1; }, @@ -146,7 +146,7 @@ fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &SortField) added } -pub fn convert_columns_amortized<'a, I: IntoIterator>( +pub fn convert_columns_amortized<'a, I: IntoIterator>( columns: &'a [ArrayRef], fields: I, rows: &mut RowsEncoded, @@ -165,11 +165,15 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( for (arr, field) in columns.iter().zip(fields) { let added = get_encoders(arr.as_ref(), &mut flattened_columns, field); for _ in 0..added { - flattened_fields.push(field.clone()); + flattened_fields.push(*field); } } - let values_size = - allocate_rows_buf(&mut flattened_columns, &mut rows.values, &mut rows.offsets); + let values_size = allocate_rows_buf( + &mut flattened_columns, + &flattened_fields, + &mut rows.values, + &mut rows.offsets, + ); for (arr, field) in flattened_columns.iter().zip(flattened_fields.iter()) { // SAFETY: // we 
allocated rows with enough bytes. @@ -182,11 +186,13 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( .iter() .map(|arr| Encoder::Leaf(arr.clone())) .collect::>(); - let values_size = allocate_rows_buf(&mut encoders, &mut rows.values, &mut rows.offsets); + let fields = fields.cloned().collect::>(); + let values_size = + allocate_rows_buf(&mut encoders, &fields, &mut rows.values, &mut rows.offsets); for (enc, field) in encoders.iter().zip(fields) { // SAFETY: // we allocated rows with enough bytes. - unsafe { encode_array(enc, field, rows) } + unsafe { encode_array(enc, &field, rows) } } // SAFETY: values are initialized unsafe { rows.values.set_len(values_size) } @@ -195,7 +201,7 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( fn encode_primitive( arr: &PrimitiveArray, - field: &SortField, + field: &EncodingField, out: &mut RowsEncoded, ) { if arr.null_count() == 0 { @@ -211,11 +217,11 @@ fn encode_primitive( /// /// # Safety /// `out` must have enough bytes allocated otherwise it will be out of bounds. -unsafe fn encode_array(encoder: &Encoder, field: &SortField, out: &mut RowsEncoded) { +unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsEncoded) { match encoder { Encoder::List { .. } => { let iter = encoder.list_iter(); - crate::variable::encode_iter(iter, out, &Default::default()) + crate::variable::encode_iter(iter, out, &EncodingField::new_unsorted()) }, Encoder::Leaf(array) => { match array.data_type() { @@ -279,6 +285,7 @@ pub fn encoded_size(data_type: &ArrowDataType) -> usize { // are initialized. fn allocate_rows_buf( columns: &mut [Encoder], + fields: &[EncodingField], values: &mut Vec, offsets: &mut Vec, ) -> usize { @@ -307,7 +314,7 @@ fn allocate_rows_buf( // for the variable length columns we must iterate to determine the length per row location let mut processed_count = 0; - for enc in columns.iter_mut() { + for (enc, enc_field) in columns.iter_mut().zip(fields) { match enc { Encoder::List { enc: inner_enc, @@ -315,6 +322,8 @@ fn allocate_rows_buf( field, original, } => { + let field = *field; + let fields = inner_enc.iter().map(|_| field).collect::>(); // Nested lists don't yet work as that requires the leaves not only allocating, but also // encoding. To make that work we must add a flag `in_list` that tell the leaves to immediately // encode the rows instead of only setting the length. @@ -332,6 +341,7 @@ fn allocate_rows_buf( // Allocate and immediately row-encode the inner types recursively. let values_size = allocate_rows_buf( inner_enc, + &fields, &mut values_rows.values, &mut values_rows.offsets, ); @@ -339,7 +349,7 @@ fn allocate_rows_buf( // For single nested it does work as we encode here. 
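                    // (At this point `values_rows` holds the fully encoded inner
                    // rows, so each outer row can later be emitted as a single
                    // variable-length blob.)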
unsafe { for enc in inner_enc { - encode_array(enc, field, &mut values_rows) + encode_array(enc, &field, &mut values_rows) } values_rows.values.set_len(values_size) }; @@ -352,13 +362,20 @@ fn allocate_rows_buf( for opt_val in iter { unsafe { lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), + row_size_fixed + + crate::variable::encoded_len( + opt_val, + &EncodingField::new_unsorted(), + ), ); } } } else { for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) + *row_length += crate::variable::encoded_len( + opt_val, + &EncodingField::new_unsorted(), + ) } } processed_count += 1; @@ -371,7 +388,8 @@ fn allocate_rows_buf( for opt_val in array.into_iter() { unsafe { lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), + row_size_fixed + + crate::variable::encoded_len(opt_val, enc_field), ); } } @@ -379,7 +397,7 @@ fn allocate_rows_buf( for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) + *row_length += crate::variable::encoded_len(opt_val, enc_field) } } processed_count += 1; @@ -390,7 +408,8 @@ fn allocate_rows_buf( for opt_val in array.into_iter() { unsafe { lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), + row_size_fixed + + crate::variable::encoded_len(opt_val, enc_field), ); } } @@ -398,7 +417,7 @@ fn allocate_rows_buf( for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) + *row_length += crate::variable::encoded_len(opt_val, enc_field) } } processed_count += 1; @@ -416,13 +435,14 @@ fn allocate_rows_buf( for opt_val in iter { unsafe { lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), + row_size_fixed + + crate::variable::encoded_len(opt_val, enc_field), ) } } } else { for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) + *row_length += crate::variable::encoded_len(opt_val, enc_field) } } processed_count += 1; @@ -514,10 +534,7 @@ mod test { let arr = BinaryViewArray::from_slice([Some("a"), Some(""), Some("meep"), Some(sentence), None]); - let field = SortField { - descending: false, - nulls_last: false, - }; + let field = EncodingField::new_sorted(false, false); let arr = arrow::compute::cast::cast(&arr, &ArrowDataType::BinaryView, Default::default()) .unwrap(); let rows_encoded = convert_columns(&[arr], &[field]); @@ -567,10 +584,7 @@ mod test { let a = [val.as_str(), val.as_str(), val.as_str()]; - let field = SortField { - descending: false, - nulls_last: false, - }; + let field = EncodingField::new_sorted(false, false); let arr = BinaryViewArray::from_slice_values(a); let rows_encoded = convert_columns_no_order(&[arr.clone().boxed()]); @@ -583,10 +597,7 @@ mod test { fn test_reverse_variable() { let a = Utf8ViewArray::from_slice_values(["one", "two", "three", "four", "five", "six"]); - let fields = &[SortField { - descending: true, - nulls_last: false, - }]; + let fields = &[EncodingField::new_sorted(true, false)]; let dtypes = [ArrowDataType::Utf8View]; @@ -614,16 +625,13 @@ mod test { values.boxed(), None, ); - let fields = &[SortField { - descending: true, - nulls_last: false, - }]; + let fields = &[EncodingField::new_sorted(true, false)]; let out = convert_columns(&[array.boxed()], fields); let out = out.into_array(); assert_eq!( out.values().iter().map(|v| *v as usize).sum::(), - 
84981 + 82411 ); } } diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index dfc9d6ff94f6..f9bdc4394b08 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -8,7 +8,7 @@ use arrow::types::NativeType; use polars_utils::slice::*; use polars_utils::total_ord::{canonical_f32, canonical_f64}; -use crate::row::{RowsEncoded, SortField}; +use crate::row::{EncodingField, RowsEncoded}; pub(crate) trait FromSlice { fn from_slice(slice: &[u8]) -> Self; @@ -168,7 +168,7 @@ fn encode_value( pub(crate) unsafe fn encode_slice( input: &[T], out: &mut RowsEncoded, - field: &SortField, + field: &EncodingField, ) { out.values.set_len(0); let values = out.values.spare_capacity_mut(); @@ -178,7 +178,7 @@ pub(crate) unsafe fn encode_slice( } #[inline] -pub(crate) fn get_null_sentinel(field: &SortField) -> u8 { +pub(crate) fn get_null_sentinel(field: &EncodingField) -> u8 { if field.nulls_last { 0xFF } else { @@ -189,7 +189,7 @@ pub(crate) fn get_null_sentinel(field: &SortField) -> u8 { pub(crate) unsafe fn encode_iter>, T: FixedLengthEncoding>( input: I, out: &mut RowsEncoded, - field: &SortField, + field: &EncodingField, ) { out.values.set_len(0); let values = out.values.spare_capacity_mut(); @@ -214,7 +214,7 @@ pub(crate) unsafe fn encode_iter>, T: FixedLengthEn pub(super) unsafe fn decode_primitive( rows: &mut [&[u8]], - field: &SortField, + field: &EncodingField, ) -> PrimitiveArray where T::Encoded: FromSlice, @@ -255,7 +255,7 @@ where PrimitiveArray::new(data_type, values.into(), validity) } -pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &SortField) -> BooleanArray { +pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &EncodingField) -> BooleanArray { let mut has_nulls = false; let null_sentinel = get_null_sentinel(field); diff --git a/crates/polars-row/src/lib.rs b/crates/polars-row/src/lib.rs index 2de299ec715c..823e5c6e4566 100644 --- a/crates/polars-row/src/lib.rs +++ b/crates/polars-row/src/lib.rs @@ -279,4 +279,4 @@ pub use encode::{ convert_columns, convert_columns_amortized, convert_columns_amortized_no_order, convert_columns_no_order, }; -pub use row::{RowsEncoded, SortField}; +pub use row::{EncodingField, RowsEncoded}; diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index 26ba4715d33b..d48f6f51c205 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -4,12 +4,32 @@ use arrow::datatypes::ArrowDataType; use arrow::ffi::mmap; use arrow::offset::{Offsets, OffsetsBuffer}; -#[derive(Clone, Default)] -pub struct SortField { +#[derive(Clone, Default, Copy)] +pub struct EncodingField { /// Whether to sort in descending order pub descending: bool, /// Whether to sort nulls first pub nulls_last: bool, + /// Ignore all order-related flags and don't encode order-preserving. + /// This is faster for variable encoding as we can just memcopy all the bytes. 
+ pub no_order: bool, +} + +impl EncodingField { + pub fn new_sorted(descending: bool, nulls_last: bool) -> Self { + EncodingField { + descending, + nulls_last, + no_order: false, + } + } + + pub fn new_unsorted() -> Self { + EncodingField { + no_order: true, + ..Default::default() + } + } } #[derive(Default, Clone)] diff --git a/crates/polars-row/src/variable.rs b/crates/polars-row/src/variable.rs index 4a582afb5e7f..5032e41085a8 100644 --- a/crates/polars-row/src/variable.rs +++ b/crates/polars-row/src/variable.rs @@ -19,7 +19,7 @@ use polars_utils::slice::{GetSaferUnchecked, Slice2Uninit}; use crate::fixed::{decode_nulls, get_null_sentinel}; use crate::row::RowsEncoded; -use crate::SortField; +use crate::EncodingField; /// The block size of the variable length encoding pub(crate) const BLOCK_SIZE: usize = 32; @@ -56,8 +56,54 @@ fn padded_length_opt(a: Option) -> usize { } #[inline] -pub fn encoded_len(a: Option<&[u8]>) -> usize { - padded_length_opt(a.map(|v| v.len())) +fn length_opt(a: Option) -> usize { + if let Some(a) = a { + 1 + a + } else { + 1 + } +} + +#[inline] +pub fn encoded_len(a: Option<&[u8]>, field: &EncodingField) -> usize { + if field.no_order { + length_opt(a.map(|v| v.len())) + } else { + padded_length_opt(a.map(|v| v.len())) + } +} + +unsafe fn encode_one_no_order( + out: &mut [MaybeUninit], + val: Option<&[MaybeUninit]>, + field: &EncodingField, +) -> usize { + match val { + Some([]) => { + let byte = if field.descending { + !EMPTY_SENTINEL + } else { + EMPTY_SENTINEL + }; + *out.get_unchecked_release_mut(0) = MaybeUninit::new(byte); + 1 + }, + Some(val) => { + let end_offset = 1 + val.len(); + + // Write `2_u8` to demarcate as non-empty, non-null string + *out.get_unchecked_release_mut(0) = MaybeUninit::new(NON_EMPTY_SENTINEL); + std::ptr::copy_nonoverlapping(val.as_ptr(), out.as_mut_ptr().add(1), val.len()); + + end_offset + }, + None => { + *out.get_unchecked_release_mut(0) = MaybeUninit::new(get_null_sentinel(field)); + // // write remainder as zeros + // out.get_unchecked_release_mut(1..).fill(MaybeUninit::new(0)); + 1 + }, + } } /// Encode one strings/bytes object and return the written length. 
@@ -67,7 +113,7 @@ pub fn encoded_len(a: Option<&[u8]>) -> usize { unsafe fn encode_one( out: &mut [MaybeUninit], val: Option<&[MaybeUninit]>, - field: &SortField, + field: &EncodingField, ) -> usize { match val { Some([]) => { @@ -150,14 +196,23 @@ unsafe fn encode_one( pub(crate) unsafe fn encode_iter<'a, I: Iterator>>( input: I, out: &mut RowsEncoded, - field: &SortField, + field: &EncodingField, ) { out.values.set_len(0); let values = out.values.spare_capacity_mut(); - for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { - let dst = values.get_unchecked_release_mut(*offset..); - let written_len = encode_one(dst, opt_value.map(|v| v.as_uninit()), field); - *offset += written_len; + + if field.no_order { + for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { + let dst = values.get_unchecked_release_mut(*offset..); + let written_len = encode_one_no_order(dst, opt_value.map(|v| v.as_uninit()), field); + *offset += written_len; + } + } else { + for (offset, opt_value) in out.offsets.iter_mut().skip(1).zip(input) { + let dst = values.get_unchecked_release_mut(*offset..); + let written_len = encode_one(dst, opt_value.map(|v| v.as_uninit()), field); + *offset += written_len; + } } let offset = out.offsets.last().unwrap(); let dst = values.get_unchecked_release_mut(*offset..); @@ -203,7 +258,7 @@ unsafe fn decoded_len( } } -pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &SortField) -> BinaryArray { +pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &EncodingField) -> BinaryArray { let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { @@ -274,7 +329,7 @@ pub(super) unsafe fn decode_binary(rows: &mut [&[u8]], field: &SortField) -> Bin ) } -pub(super) unsafe fn decode_binview(rows: &mut [&[u8]], field: &SortField) -> BinaryViewArray { +pub(super) unsafe fn decode_binview(rows: &mut [&[u8]], field: &EncodingField) -> BinaryViewArray { let (non_empty_sentinel, continuation_token) = if field.descending { (!NON_EMPTY_SENTINEL, !BLOCK_CONTINUATION_TOKEN) } else { From 2c9b6c14a424040cf708250ecf40531311c70b6e Mon Sep 17 00:00:00 2001 From: JamesCE2001 <45176743+JamesCE2001@users.noreply.github.com> Date: Mon, 1 Apr 2024 07:22:14 -0400 Subject: [PATCH 26/30] feat(rust, python): Add `null_on_oob` parameter to `expr.list.get` (#15395) Co-authored-by: James Edwards --- .../polars-arrow/src/legacy/kernels/list.rs | 7 ++ .../src/chunked_array/list/namespace.rs | 8 +- .../src/chunked_array/list/to_struct.rs | 2 +- .../polars-plan/src/dsl/function_expr/list.rs | 33 ++++--- .../polars-plan/src/dsl/function_expr/mod.rs | 8 ++ crates/polars-plan/src/dsl/list.rs | 8 +- crates/polars-sql/src/functions.rs | 2 +- py-polars/polars/expr/list.py | 13 ++- py-polars/polars/series/list.py | 13 ++- py-polars/src/expr/list.rs | 8 +- py-polars/tests/unit/datatypes/test_list.py | 2 +- .../tests/unit/namespaces/list/test_list.py | 99 ++++++++++++++++--- 12 files changed, 159 insertions(+), 44 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index e67d1638e99d..46c339323b1b 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -75,6 +75,13 @@ pub fn sublist_get(arr: &ListArray, index: i64) -> ArrayRef { unsafe { take_unchecked(&**values, &take_by) } } +/// Check if an index is out of bounds for at least one sublist. 
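+///
+/// E.g. with offsets `[0, 2, 3]` (sublist lengths 2 and 1), index `1` is in
+/// bounds for the first sublist but out of bounds for the second.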
+pub fn index_is_oob(arr: &ListArray, index: i64) -> bool { + arr.offsets() + .lengths() + .any(|len| index.negative_to_usize(len).is_none()) +} + /// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]` pub fn array_to_unit_list(array: ArrayRef) -> ListArray { let len = array.len(); diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index a4f7e78e2c6d..0d511c87967c 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -1,7 +1,7 @@ use std::fmt::Write; use arrow::array::ValueSize; -use arrow::legacy::kernels::list::sublist_get; +use arrow::legacy::kernels::list::{index_is_oob, sublist_get}; use polars_core::chunked_array::builder::get_list_builder; #[cfg(feature = "list_gather")] use polars_core::export::num::ToPrimitive; @@ -341,8 +341,12 @@ pub trait ListNameSpaceImpl: AsList { /// So index `0` would return the first item of every sublist /// and index `-1` would return the last item of every sublist /// if an index is out of bounds, it will return a `None`. - fn lst_get(&self, idx: i64) -> PolarsResult { + fn lst_get(&self, idx: i64, null_on_oob: bool) -> PolarsResult { let ca = self.as_list(); + if !null_on_oob && ca.downcast_iter().any(|arr| index_is_oob(arr, idx)) { + polars_bail!(ComputeError: "get index is out of bounds"); + } + let chunks = ca .downcast_iter() .map(|arr| sublist_get(arr, idx)) diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index c43cfda13024..4b74a76692ed 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -72,7 +72,7 @@ pub trait ToStruct: AsList { (0..n_fields) .into_par_iter() .map(|i| { - ca.lst_get(i as i64).map(|mut s| { + ca.lst_get(i as i64, true).map(|mut s| { s.rename(&name_generator(i)); s }) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 3fdbf6a18134..3b06841fbb55 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -21,7 +21,7 @@ pub enum ListFunction { }, Slice, Shift, - Get, + Get(bool), #[cfg(feature = "list_gather")] Gather(bool), #[cfg(feature = "list_gather")] @@ -71,7 +71,7 @@ impl ListFunction { Sample { .. 
} => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), Shift => mapper.with_same_dtype(), - Get => mapper.map_to_list_and_array_inner_dtype(), + Get(_) => mapper.map_to_list_and_array_inner_dtype(), #[cfg(feature = "list_gather")] Gather(_) => mapper.with_same_dtype(), #[cfg(feature = "list_gather")] @@ -136,7 +136,7 @@ impl Display for ListFunction { }, Slice => "slice", Shift => "shift", - Get => "get", + Get(_) => "get", #[cfg(feature = "list_gather")] Gather(_) => "gather", #[cfg(feature = "list_gather")] @@ -203,9 +203,9 @@ impl From for SpecialEq> { }, Slice => wrap!(slice), Shift => map_as_slice!(shift), - Get => wrap!(get), + Get(null_on_oob) => wrap!(get, null_on_oob), #[cfg(feature = "list_gather")] - Gather(null_ob_oob) => map_as_slice!(gather, null_ob_oob), + Gather(null_on_oob) => map_as_slice!(gather, null_on_oob), #[cfg(feature = "list_gather")] GatherEvery => map_as_slice!(gather_every), #[cfg(feature = "list_count")] @@ -414,7 +414,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult> { first_ca.lst_concat(other).map(|ca| Some(ca.into_series())) } -pub(super) fn get(s: &mut [Series]) -> PolarsResult> { +pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult> { let ca = s[0].list()?; let index = s[1].cast(&DataType::Int64)?; let index = index.i64().unwrap(); @@ -423,7 +423,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { 1 => { let index = index.get(0); if let Some(index) = index { - ca.lst_get(index).map(Some) + ca.lst_get(index, null_on_oob).map(Some) } else { Ok(Some(Series::full_null( ca.name(), @@ -440,19 +440,24 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { let take_by = index .into_iter() .enumerate() - .map(|(i, opt_idx)| { - opt_idx.and_then(|idx| { + .map(|(i, opt_idx)| match opt_idx { + Some(idx) => { let (start, end) = unsafe { (*offsets.get_unchecked(i), *offsets.get_unchecked(i + 1)) }; let offset = if idx >= 0 { start + idx } else { end + idx }; if offset >= end || offset < start || start == end { - None + if null_on_oob { + Ok(None) + } else { + polars_bail!(ComputeError: "get index is out of bounds"); + } } else { - Some(offset as IdxSize) + Ok(Some(offset as IdxSize)) } - }) + }, + None => Ok(None), }) - .collect::(); + .collect::>()?; let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); unsafe { s.take_unchecked(&take_by) } .cast(&ca.inner_dtype()) @@ -475,7 +480,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult if idx.len() == 1 && null_on_oob { // fast path let idx = idx.get(0)?.try_extract::()?; - let out = ca.lst_get(idx)?; + let out = ca.lst_get(idx, null_on_oob)?; // make sure we return a list out.reshape(&[-1, 1]) } else { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 397959e9980f..5ff04aee0098 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -718,6 +718,14 @@ macro_rules! wrap { ($e:expr) => { SpecialEq::new(Arc::new($e)) }; + + ($e:expr, $($args:expr),*) => {{ + let f = move |s: &mut [Series]| { + $e(s, $($args),*) + }; + + SpecialEq::new(Arc::new(f)) + }}; } // Fn(&[Series], args) diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 603ec2553590..0f6c15c755e7 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -151,9 +151,9 @@ impl ListNameSpace { } /// Get items in every sublist by index. 
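+    ///
+    /// If `null_on_oob` is false, an out-of-bounds index raises a
+    /// `ComputeError` instead of producing a null value.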
- pub fn get(self, index: Expr) -> Expr { + pub fn get(self, index: Expr, null_on_oob: bool) -> Expr { self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Get), + FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)), &[index], false, false, @@ -187,12 +187,12 @@ impl ListNameSpace { /// Get first item of every sublist. pub fn first(self) -> Expr { - self.get(lit(0i64)) + self.get(lit(0i64), true) } /// Get last item of every sublist. pub fn last(self) -> Expr { - self.get(lit(-1i64)) + self.get(lit(-1i64), true) } /// Join all string items in a sublist and place a separator between them. diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 2149912be665..6cfd5263c416 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -987,7 +987,7 @@ impl SQLFunctionVisitor<'_> { // Array functions // ---- ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), - ArrayGet => self.visit_binary(|e, i| e.list().get(i)), + ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)), ArrayLength => self.visit_unary(|e| e.list().len()), ArrayMax => self.visit_unary(|e| e.list().max()), ArrayMean => self.visit_unary(|e| e.list().mean()), diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 474e586a1c62..3c827794ffdb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -505,7 +505,12 @@ def concat(self, other: list[Expr | str] | Expr | str | Series | list[Any]) -> E other_list.insert(0, wrap_expr(self._pyexpr)) return F.concat_list(other_list) - def get(self, index: int | Expr | str) -> Expr: + def get( + self, + index: int | Expr | str, + *, + null_on_oob: bool = True, + ) -> Expr: """ Get the value by index in the sublists. @@ -517,6 +522,10 @@ def get(self, index: int | Expr | str) -> Expr: ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- @@ -534,7 +543,7 @@ def get(self, index: int | Expr | str) -> Expr: └───────────┴──────┘ """ index = parse_as_expression(index) - return wrap_expr(self._pyexpr.list_get(index)) + return wrap_expr(self._pyexpr.list_get(index, null_on_oob)) def gather( self, diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 9540cc6a2860..7879d96c9ea2 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -373,7 +373,12 @@ def concat(self, other: list[Series] | Series | list[Any]) -> Series: ] """ - def get(self, index: int | Series | list[int]) -> Series: + def get( + self, + index: int | Series | list[int], + *, + null_on_oob: bool = True, + ) -> Series: """ Get the value by index in the sublists. 
@@ -385,11 +390,15 @@ def get(self, index: int | Series | list[int]) -> Series: ---------- index Index to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error Examples -------- >>> s = pl.Series("a", [[3, 2, 1], [], [1, 2]]) - >>> s.list.get(0) + >>> s.list.get(0, null_on_oob=True) shape: (3,) Series: 'a' [i64] [ diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index fde544a6ce41..b00476c7bb3a 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -44,8 +44,12 @@ impl PyExpr { self.inner.clone().list().eval(expr.inner, parallel).into() } - fn list_get(&self, index: PyExpr) -> Self { - self.inner.clone().list().get(index.inner).into() + fn list_get(&self, index: PyExpr, null_on_oob: bool) -> Self { + self.inner + .clone() + .list() + .get(index.inner, null_on_oob) + .into() } fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index ba580ea8e6ba..fe0028d5067c 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -781,7 +781,7 @@ def test_list_gather_null_struct_14927() -> None: {"index": [1], "col_0": [None], "field_0": [None]}, schema={**df.schema, "field_0": pl.Float64}, ) - expr = pl.col("col_0").list.get(0).struct.field("field_0") + expr = pl.col("col_0").list.get(0, null_on_oob=True).struct.field("field_0") out = df.filter(pl.col("index") > 0).with_columns(expr) assert_frame_equal(out, expected) diff --git a/py-polars/tests/unit/namespaces/list/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py index 570716e14fe5..40dc3561598c 100644 --- a/py-polars/tests/unit/namespaces/list/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -11,7 +11,7 @@ def test_list_arr_get() -> None: a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(0) + out = a.list.get(0, null_on_oob=False) expected = pl.Series("a", [1, 4, 6]) assert_series_equal(out, expected) out = a.list[0] @@ -22,7 +22,74 @@ def test_list_arr_get() -> None: out = pl.select(pl.lit(a).list.first()).to_series() assert_series_equal(out, expected) - out = a.list.get(-1) + out = a.list.get(-1, null_on_oob=False) + expected = pl.Series("a", [3, 5, 9]) + assert_series_equal(out, expected) + out = a.list.last() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.last()).to_series() + assert_series_equal(out, expected) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(3, null_on_oob=False) + + # Null index. 
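+    # Even with `null_on_oob=False`, a null index produces a null result;
+    # the flag only governs out-of-bounds integer indices.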
+ out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=False)) + expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() + assert_frame_equal(out_df, expected_df) + + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + a.list.get(-3, null_on_oob=False) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + pl.DataFrame( + {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} + ).with_columns( + [ + pl.col("a").list.get(i, null_on_oob=False).alias(f"get_{i}") + for i in range(4) + ] + ) + + # get by indexes where some are out of bounds + df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select([pl.col("cars").list.get("indexes", null_on_oob=False)]).to_dict( + as_series=False + ) + + # exact on oob boundary + df = pl.DataFrame( + { + "index": [3, 3, 3], + "lists": [[3, 4, 5], [4, 5, 6], [7, 8, 9, 4]], + } + ) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(3, null_on_oob=False)) + + with pytest.raises(pl.ComputeError, match="get index is out of bounds"): + df.select(pl.col("lists").list.get(pl.col("index"), null_on_oob=False)) + + +def test_list_arr_get_null_on_oob() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + out = a.list.get(0, null_on_oob=True) + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list[0] + expected = pl.Series("a", [1, 4, 6]) + assert_series_equal(out, expected) + out = a.list.first() + assert_series_equal(out, expected) + out = pl.select(pl.lit(a).list.first()).to_series() + assert_series_equal(out, expected) + + out = a.list.get(-1, null_on_oob=True) expected = pl.Series("a", [3, 5, 9]) assert_series_equal(out, expected) out = a.list.last() @@ -31,24 +98,24 @@ def test_list_arr_get() -> None: assert_series_equal(out, expected) # Out of bounds index. - out = a.list.get(3) + out = a.list.get(3, null_on_oob=True) expected = pl.Series("a", [None, None, 9]) assert_series_equal(out, expected) # Null index. 
- out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None))) + out_df = a.to_frame().select(pl.col.a.list.get(pl.lit(None), null_on_oob=True)) expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() assert_frame_equal(out_df, expected_df) a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) - out = a.list.get(-3) + out = a.list.get(-3, null_on_oob=True) expected = pl.Series("a", [1, None, 7]) assert_series_equal(out, expected) assert pl.DataFrame( {"a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]]} ).with_columns( - [pl.col("a").list.get(i).alias(f"get_{i}") for i in range(4)] + [pl.col("a").list.get(i, null_on_oob=True).alias(f"get_{i}") for i in range(4)] ).to_dict(as_series=False) == { "a": [[1], [2], [3], [4, 5, 6], [7, 8, 9], [None, 11]], "get_0": [1, 2, 3, 4, 7, None], @@ -60,9 +127,9 @@ def test_list_arr_get() -> None: # get by indexes where some are out of bounds df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []], "indexes": [-2, 1, -3, 0]}) - assert df.select([pl.col("cars").list.get("indexes")]).to_dict(as_series=False) == { - "cars": [2, 3, None, None] - } + assert df.select([pl.col("cars").list.get("indexes", null_on_oob=True)]).to_dict( + as_series=False + ) == {"cars": [2, 3, None, None]} # exact on oob boundary df = pl.DataFrame( { @@ -71,12 +138,12 @@ def test_list_arr_get() -> None: } ) - assert df.select(pl.col("lists").list.get(3)).to_dict(as_series=False) == { - "lists": [None, None, 4] - } - assert df.select(pl.col("lists").list.get(pl.col("index"))).to_dict( + assert df.select(pl.col("lists").list.get(3, null_on_oob=True)).to_dict( as_series=False ) == {"lists": [None, None, 4]} + assert df.select( + pl.col("lists").list.get(pl.col("index"), null_on_oob=True) + ).to_dict(as_series=False) == {"lists": [None, None, 4]} def test_list_categorical_get() -> None: @@ -88,7 +155,9 @@ def test_list_categorical_get() -> None: } ) expected = pl.Series("actions", ["a", "c", None, None], dtype=pl.Categorical) - assert_series_equal(df["actions"].list.get(0), expected, categorical_as_str=True) + assert_series_equal( + df["actions"].list.get(0, null_on_oob=True), expected, categorical_as_str=True + ) def test_contains() -> None: @@ -597,7 +666,7 @@ def test_select_from_list_to_struct_11143() -> None: def test_list_arr_get_8810() -> None: assert pl.DataFrame(pl.Series("a", [None], pl.List(pl.Int64))).select( - pl.col("a").list.get(0) + pl.col("a").list.get(0, null_on_oob=True) ).to_dict(as_series=False) == {"a": [None]} From b1eaff37b87ef91a358288fba47fb6895bac9c33 Mon Sep 17 00:00:00 2001 From: Thomas <37830237+thomaslin2020@users.noreply.github.com> Date: Mon, 1 Apr 2024 07:23:31 -0400 Subject: [PATCH 27/30] docs(python): Added example for `explode` mapping strategy in `pl.Expr.over` (#15402) --- py-polars/polars/expr/expr.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 39ddacbb99d1..617459587ab8 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3245,8 +3245,11 @@ def over( Join the groups as 'List' to the row positions. warning: this can be memory intensive. - explode - Don't do any mapping, but simply flatten the group. - This only makes sense if the input data is sorted. + Explodes the grouped data into new rows, similar to the results of + `group_by` + `agg` + `explode`. 
Sorting of the given groups is required + if the groups are not part of the window operation for the operation, + otherwise the result would not make sense. This operation changes the + number of rows. Examples -------- @@ -3328,6 +3331,26 @@ def over( │ b ┆ 5 ┆ 2 ┆ 1 │ │ b ┆ 3 ┆ 1 ┆ 1 │ └─────┴─────┴─────┴───────┘ + + Aggregate values from each group using `mapping_strategy="explode"`. + + >>> df.select( + ... pl.col("a").head(2).over("a", mapping_strategy="explode"), + ... pl.col("b").sort_by("b").head(2).over("a", mapping_strategy="explode"), + ... pl.col("c").sort_by("b").head(2).over("a", mapping_strategy="explode"), + ... ) + shape: (4, 3) + ┌─────┬─────┬─────┐ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ str ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ a ┆ 1 ┆ 5 │ + │ a ┆ 2 ┆ 4 │ + │ b ┆ 3 ┆ 3 │ + │ b ┆ 3 ┆ 1 │ + └─────┴─────┴─────┘ + """ exprs = parse_as_list_of_expressions(expr, *more_exprs) return self._from_pyexpr(self._pyexpr.over(exprs, mapping_strategy)) From 82f717b45ecb3c580cd788a41d717dd2f2ae539a Mon Sep 17 00:00:00 2001 From: Lava <34743145+CanglongCl@users.noreply.github.com> Date: Mon, 1 Apr 2024 05:23:00 -0700 Subject: [PATCH 28/30] feat(python): Add `read_clipboard` and `DataFrame.write_clipboard` (#15272) --- Cargo.lock | 103 +++++++++++++++++++++++++ Cargo.toml | 1 + py-polars/Cargo.toml | 3 + py-polars/docs/source/reference/io.rst | 8 ++ py-polars/polars/__init__.py | 2 + py-polars/polars/dataframe/frame.py | 22 ++++++ py-polars/polars/io/__init__.py | 2 + py-polars/polars/io/clipboard.py | 36 +++++++++ py-polars/src/functions/io.rs | 24 ++++++ py-polars/src/lib.rs | 6 ++ 10 files changed, 207 insertions(+) create mode 100644 py-polars/polars/io/clipboard.py diff --git a/Cargo.lock b/Cargo.lock index 93c1a6c43eaa..8be56fa3464f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,6 +125,22 @@ dependencies = [ "uuid", ] +[[package]] +name = "arboard" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2041f1943049c7978768d84e6d0fd95de98b76d6c4727b09e78ec253d29fa58" +dependencies = [ + "clipboard-win", + "log", + "objc", + "objc-foundation", + "objc_id", + "parking_lot", + "thiserror", + "x11rb", +] + [[package]] name = "argminmax" version = "0.6.2" @@ -712,6 +728,12 @@ dependencies = [ "serde", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -925,6 +947,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "clipboard-win" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d517d4b86184dbb111d3556a10f1c8a04da7428d2987bf1081602bf11c3aa9ee" +dependencies = [ + "error-code", +] + [[package]] name = "cmake" version = "0.1.50" @@ -1311,6 +1342,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "error-code" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0474425d51df81997e2f90a21591180b38eccf27292d755f3e30750225c175b" + [[package]] name = "ethnum" version = "1.5.0" @@ -1485,6 +1522,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "gethostname" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0176e0459c2e4a1fe232f984bca6890e681076abb9934f6cea7c326f3fc47818" +dependencies = [ + "libc", + "windows-targets 0.48.5", +] + [[package]] name = "getrandom" version = "0.2.12" @@ -2110,6 +2157,15 @@ dependencies = [ "libc", ] +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matrixmultiply" version = "0.3.8" @@ -2349,6 +2405,35 @@ dependencies = [ "rustc-hash", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + +[[package]] +name = "objc-foundation" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1add1b659e36c9607c7aab864a76c7a4c2760cd0cd2e120f3fb8b952c7e22bf9" +dependencies = [ + "block", + "objc", + "objc_id", +] + +[[package]] +name = "objc_id" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c92d4ddb4bd7b50d730c215ff871754d0da6b2178849f8a2a2ab69712d0c073b" +dependencies = [ + "objc", +] + [[package]] name = "object" version = "0.32.2" @@ -3108,6 +3193,7 @@ name = "py-polars" version = "0.20.17" dependencies = [ "ahash", + "arboard", "built", "ciborium", "either", @@ -4788,6 +4874,23 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "x11rb" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8f25ead8c7e4cba123243a6367da5d3990e0d3affa708ea19dce96356bd9f1a" +dependencies = [ + "gethostname", + "rustix", + "x11rb-protocol", +] + +[[package]] +name = "x11rb-protocol" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e63e71c4b8bd9ffec2c963173a4dc4cbde9ee96961d4fcb4429db9929b606c34" + [[package]] name = "xmlparser" version = "0.13.6" diff --git a/Cargo.toml b/Cargo.toml index 7d75530dc2e1..ef4b780ce9bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ version_check = "0.9.4" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" uuid = { version = "1.7.0", features = ["v4"] } +arboard = { version = "3.3.2", default-features = false } polars = { version = "0.38.3", path = "crates/polars", default-features = false } polars-compute = { version = "0.38.3", path = "crates/polars-compute", default-features = false } diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index a48000225eab..85bcbf976fdd 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -17,6 +17,7 @@ polars-plan = { workspace = true } polars-utils = { workspace = true } ahash = { workspace = true } +arboard = { workspace = true, optional = true } ciborium = { workspace = true } either = { workspace = true } itoa = { workspace = true } @@ -126,6 +127,7 @@ search_sorted = ["polars/search_sorted"] decompress = ["polars/decompress-fast"] regex = ["polars/regex"] csv = ["polars/csv"] +clipboard = ["arboard"] object = ["polars/object"] extract_jsonpath = ["polars/extract_jsonpath"] pivot = ["polars/pivot"] @@ -204,6 +206,7 @@ io = [ "avro", "csv", "cloud", + "clipboard", ] optimizations = [ diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index d3c45469f94a..1f088958a3c0 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ 
-11,6 +11,14 @@ Avro read_avro DataFrame.write_avro +Clipboard +~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + read_clipboard + DataFrame.write_clipboard + CSV ~~~ .. autosummary:: diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 4691a6f13cc2..492c3e437f63 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -180,6 +180,7 @@ ) from polars.io import ( read_avro, + read_clipboard, read_csv, read_csv_batched, read_database, @@ -316,6 +317,7 @@ "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", + "read_clipboard", # polars.stringcache "StringCache", "disable_string_cache", diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 2d0e9cde550e..a34c4aae5b84 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -118,6 +118,7 @@ with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame from polars.polars import dtype_str_repr as _dtype_str_repr + from polars.polars import write_clipboard_string as _write_clipboard_string if TYPE_CHECKING: import sys @@ -2595,6 +2596,27 @@ def write_csv( return None + def write_clipboard(self, *, separator: str = "\t", **kwargs: Any) -> None: + """ + Copy `DataFrame` in csv format to the system clipboard with `write_csv`. + + Useful for pasting into Excel or other similar spreadsheet software. + + Parameters + ---------- + separator + Separate CSV fields with this symbol. + kwargs + Additional arguments to pass to `write_csv`. + + See Also + -------- + polars.read_clipboard: Read a DataFrame from the clipboard. + write_csv: Write to comma-separated values (CSV) file. + """ + result: str = self.write_csv(file=None, separator=separator, **kwargs) + _write_clipboard_string(result) + def write_avro( self, file: str | Path | IO[bytes], diff --git a/py-polars/polars/io/__init__.py b/py-polars/polars/io/__init__.py index 395f15bd4c94..35f61f1fb596 100644 --- a/py-polars/polars/io/__init__.py +++ b/py-polars/polars/io/__init__.py @@ -1,6 +1,7 @@ """Functions for reading data.""" from polars.io.avro import read_avro +from polars.io.clipboard import read_clipboard from polars.io.csv import read_csv, read_csv_batched, scan_csv from polars.io.database import read_database, read_database_uri from polars.io.delta import read_delta, scan_delta @@ -35,4 +36,5 @@ "scan_ndjson", "scan_parquet", "scan_pyarrow_dataset", + "read_clipboard", ] diff --git a/py-polars/polars/io/clipboard.py b/py-polars/polars/io/clipboard.py new file mode 100644 index 000000000000..aa441ded0429 --- /dev/null +++ b/py-polars/polars/io/clipboard.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import contextlib +from io import StringIO +from typing import TYPE_CHECKING, Any + +from polars.io.csv.functions import read_csv + +with contextlib.suppress(ImportError): + from polars.polars import read_clipboard_string as _read_clipboard_string + +if TYPE_CHECKING: + from polars import DataFrame + + +def read_clipboard(separator: str = "\t", **kwargs: Any) -> DataFrame: + """ + Read text from clipboard and pass to `read_csv`. + + Useful for reading data copied from Excel or other similar spreadsheet software. + + Parameters + ---------- + separator + Single byte character to use as separator parsing csv from clipboard. + kwargs + Additional arguments passed to `read_csv`. + + See Also + -------- + read_csv : Read a csv file into a DataFrame. + DataFrame.write_clipboard : Write a DataFrame to the clipboard. 
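+
+    Notes
+    -----
+    A system clipboard must be available. The backing `arboard` crate
+    (pinned in the workspace `Cargo.lock` with an X11 backend on Linux)
+    will typically raise an error in headless environments where no
+    display server is running.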
+ """ + csv_string: str = _read_clipboard_string() + io_string = StringIO(csv_string) + return read_csv(source=io_string, separator=separator, **kwargs) diff --git a/py-polars/src/functions/io.rs b/py-polars/src/functions/io.rs index 4f79dc46f873..212d16b19210 100644 --- a/py-polars/src/functions/io.rs +++ b/py-polars/src/functions/io.rs @@ -56,3 +56,27 @@ fn fields_to_pydict(fields: &Vec, dict: &PyDict, py: Python) -> PyResult< } Ok(()) } + +#[cfg(feature = "clipboard")] +#[pyfunction] +pub fn read_clipboard_string() -> PyResult { + use arboard; + let mut clipboard = + arboard::Clipboard::new().map_err(|e| PyPolarsErr::Other(format!("{e}")))?; + let result = clipboard + .get_text() + .map_err(|e| PyPolarsErr::Other(format!("{e}")))?; + Ok(result) +} + +#[cfg(feature = "clipboard")] +#[pyfunction] +pub fn write_clipboard_string(s: &str) -> PyResult<()> { + use arboard; + let mut clipboard = + arboard::Clipboard::new().map_err(|e| PyPolarsErr::Other(format!("{e}")))?; + clipboard + .set_text(s) + .map_err(|e| PyPolarsErr::Other(format!("{e}")))?; + Ok(()) +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index e1c978f595d7..787e747a8b56 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -217,6 +217,12 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { #[cfg(feature = "parquet")] m.add_wrapped(wrap_pyfunction!(functions::read_parquet_schema)) .unwrap(); + #[cfg(feature = "clipboard")] + m.add_wrapped(wrap_pyfunction!(functions::read_clipboard_string)) + .unwrap(); + #[cfg(feature = "clipboard")] + m.add_wrapped(wrap_pyfunction!(functions::write_clipboard_string)) + .unwrap(); // Functions - meta m.add_wrapped(wrap_pyfunction!(functions::get_index_type)) From 758b55a58010b45c4b4e06ee500d1e8b16cba547 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 1 Apr 2024 14:23:13 +0200 Subject: [PATCH 29/30] perf: Make LogicalPlan immutable (#15416) --- crates/polars-lazy/src/frame/mod.rs | 6 +- .../executors/group_by_partitioned.rs | 2 +- .../polars-plan/src/logical_plan/builder.rs | 42 +++++++------- .../src/logical_plan/conversion.rs | 55 ++++++++++--------- crates/polars-plan/src/logical_plan/mod.rs | 28 +++++----- 5 files changed, 67 insertions(+), 66 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index cc96ce2173d3..f08646959e3f 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -751,7 +751,7 @@ impl LazyFrame { ) -> PolarsResult<()> { self.opt_state.streaming = true; self.logical_plan = LogicalPlan::Sink { - input: Box::new(self.logical_plan), + input: Arc::new(self.logical_plan), payload: SinkType::Cloud { uri: Arc::new(uri), cloud_options, @@ -806,7 +806,7 @@ impl LazyFrame { fn sink(mut self, payload: SinkType, msg_alternative: &str) -> Result<(), PolarsError> { self.opt_state.streaming = true; self.logical_plan = LogicalPlan::Sink { - input: Box::new(self.logical_plan), + input: Arc::new(self.logical_plan), payload, }; let (mut state, mut physical_plan, is_streaming) = self.prepare_collect(true)?; @@ -1846,7 +1846,7 @@ impl LazyGroupBy { let options = GroupbyOptions { slice: None }; let lp = LogicalPlan::Aggregate { - input: Box::new(self.logical_plan), + input: Arc::new(self.logical_plan), keys: Arc::new(self.keys), aggs: vec![], schema, diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs index d10a09dffcbb..0f5f1e50901b 100644 --- 
a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs @@ -257,7 +257,7 @@ impl PartitionGroupByExec { } .into(); let lp = LogicalPlan::Aggregate { - input: Box::new(original_df.lazy().logical_plan), + input: Arc::new(original_df.lazy().logical_plan), keys: Arc::new(std::mem::take(&mut self.keys)), aggs: std::mem::take(&mut self.aggs), schema: self.output_schema.clone(), diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index 3f1fba60c172..9a257eb95b9f 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -68,7 +68,7 @@ macro_rules! raise_err { let err = $err.wrap_msg(&format_err_outer); LogicalPlan::Error { - input: Box::new(input), + input: Arc::new(input), err: err.into(), // PolarsError -> ErrorState } }, @@ -424,7 +424,7 @@ impl LogicalPlanBuilder { } pub fn cache(self) -> Self { - let input = Box::new(self.0); + let input = Arc::new(self.0); let id = input.as_ref() as *const LogicalPlan as usize; LogicalPlan::Cache { input, @@ -461,7 +461,7 @@ impl LogicalPlanBuilder { } else { LogicalPlan::Projection { expr: columns, - input: Box::new(self.0), + input: Arc::new(self.0), schema: Arc::new(output_schema), options: ProjectionOptions { run_parallel: false, @@ -486,7 +486,7 @@ impl LogicalPlanBuilder { } else { LogicalPlan::Projection { expr: exprs, - input: Box::new(self.0), + input: Arc::new(self.0), schema: Arc::new(schema), options, } @@ -576,7 +576,7 @@ impl LogicalPlanBuilder { } LogicalPlan::HStack { - input: Box::new(self.0), + input: Arc::new(self.0), exprs, schema: Arc::new(new_schema), options, @@ -586,7 +586,7 @@ impl LogicalPlanBuilder { pub fn add_err(self, err: PolarsError) -> Self { LogicalPlan::Error { - input: Box::new(self.0), + input: Arc::new(self.0), err: err.into(), } .into() @@ -608,7 +608,7 @@ impl LogicalPlanBuilder { } } LogicalPlan::ExtContext { - input: Box::new(self.0), + input: Arc::new(self.0), contexts, schema: Arc::new(schema), } @@ -692,7 +692,7 @@ impl LogicalPlanBuilder { LogicalPlan::Selection { predicate, - input: Box::new(self.0), + input: Arc::new(self.0), } .into() } @@ -777,7 +777,7 @@ impl LogicalPlanBuilder { }; LogicalPlan::Aggregate { - input: Box::new(self.0), + input: Arc::new(self.0), keys: Arc::new(keys), aggs, schema: Arc::new(schema), @@ -814,7 +814,7 @@ impl LogicalPlanBuilder { let schema = try_delayed!(self.0.schema(), &self.0, into); let by_column = try_delayed!(rewrite_projections(by_column, &schema, &[]), &self.0, into); LogicalPlan::Sort { - input: Box::new(self.0), + input: Arc::new(self.0), by_column, args: SortArguments { descending, @@ -846,7 +846,7 @@ impl LogicalPlanBuilder { try_delayed!(explode_schema(&mut schema, &columns), &self.0, into); LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function: FunctionNode::Explode { columns, schema: Arc::new(schema), @@ -859,7 +859,7 @@ impl LogicalPlanBuilder { let schema = try_delayed!(self.0.schema(), &self.0, into); let schema = det_melt_schema(&args, &schema); LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function: FunctionNode::Melt { args, schema }, } .into() @@ -871,7 +871,7 @@ impl LogicalPlanBuilder { row_index_schema(schema_mut, name); LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function: FunctionNode::RowIndex { name: ColumnName::from(name), offset, @@ 
-883,7 +883,7 @@ impl LogicalPlanBuilder { pub fn distinct(self, options: DistinctOptions) -> Self { LogicalPlan::Distinct { - input: Box::new(self.0), + input: Arc::new(self.0), options, } .into() @@ -891,7 +891,7 @@ impl LogicalPlanBuilder { pub fn slice(self, offset: i64, len: IdxSize) -> Self { LogicalPlan::Slice { - input: Box::new(self.0), + input: Arc::new(self.0), offset, len, } @@ -908,7 +908,7 @@ impl LogicalPlanBuilder { for e in left_on.iter().chain(right_on.iter()) { if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { return LogicalPlan::Error { - input: Box::new(self.0), + input: Arc::new(self.0), err: polars_err!( ComputeError: "'alias' is not allowed in a join key, use 'with_columns' first", @@ -929,8 +929,8 @@ impl LogicalPlanBuilder { ); LogicalPlan::Join { - input_left: Box::new(self.0), - input_right: Box::new(other), + input_left: Arc::new(self.0), + input_right: Arc::new(other), schema, left_on, right_on, @@ -940,7 +940,7 @@ impl LogicalPlanBuilder { } pub fn map_private(self, function: FunctionNode) -> Self { LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function, } .into() @@ -955,7 +955,7 @@ impl LogicalPlanBuilder { validate_output: bool, ) -> Self { LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function: FunctionNode::OpaquePython { function, schema, @@ -981,7 +981,7 @@ impl LogicalPlanBuilder { let function = Arc::new(function); LogicalPlan::MapFunction { - input: Box::new(self.0), + input: Arc::new(self.0), function: FunctionNode::Opaque { function, schema, diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index c78e8c5b56ea..d3fd74f6830a 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -274,6 +274,7 @@ pub fn to_alp( expr_arena: &mut Arena, lp_arena: &mut Arena, ) -> PolarsResult { + let owned = Arc::unwrap_or_clone; let v = match lp { LogicalPlan::Scan { file_info, @@ -317,7 +318,7 @@ pub fn to_alp( } }, LogicalPlan::Selection { input, predicate } => { - let i = to_alp(*input, expr_arena, lp_arena)?; + let i = to_alp(owned(input), expr_arena, lp_arena)?; let p = to_expr_ir(predicate, expr_arena); ALogicalPlan::Selection { input: i, @@ -325,7 +326,7 @@ pub fn to_alp( } }, LogicalPlan::Slice { input, offset, len } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::Slice { input, offset, len } }, LogicalPlan::DataFrameScan { @@ -349,7 +350,7 @@ pub fn to_alp( } => { let eirs = to_expr_irs(expr, expr_arena); let expr = eirs.into(); - let i = to_alp(*input, expr_arena, lp_arena)?; + let i = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::Projection { expr, input: i, @@ -362,7 +363,7 @@ pub fn to_alp( by_column, args, } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; let by_column = to_expr_irs(by_column, expr_arena); ALogicalPlan::Sort { input, @@ -375,7 +376,7 @@ pub fn to_alp( id, cache_hits, } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::Cache { input, id, @@ -391,7 +392,7 @@ pub fn to_alp( maintain_order, options, } => { - let i = to_alp(*input, expr_arena, lp_arena)?; + let i = to_alp(owned(input), expr_arena, lp_arena)?; let aggs = to_expr_irs(aggs, expr_arena); let keys = keys.convert(|e| to_expr_ir(e.clone(), 
expr_arena)); @@ -413,8 +414,8 @@ pub fn to_alp( right_on, options, } => { - let input_left = to_alp(*input_left, expr_arena, lp_arena)?; - let input_right = to_alp(*input_right, expr_arena, lp_arena)?; + let input_left = to_alp(owned(input_left), expr_arena, lp_arena)?; + let input_right = to_alp(owned(input_right), expr_arena, lp_arena)?; let left_on = to_expr_irs_ignore_alias(left_on, expr_arena); let right_on = to_expr_irs_ignore_alias(right_on, expr_arena); @@ -436,7 +437,7 @@ pub fn to_alp( } => { let eirs = to_expr_irs(exprs, expr_arena); let exprs = eirs.into(); - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::HStack { input, exprs, @@ -445,11 +446,11 @@ pub fn to_alp( } }, LogicalPlan::Distinct { input, options } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::Distinct { input, options } }, LogicalPlan::MapFunction { input, function } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::MapFunction { input, function } }, LogicalPlan::Error { err, .. } => { @@ -462,7 +463,7 @@ pub fn to_alp( contexts, schema, } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; let contexts = contexts .into_iter() .map(|lp| to_alp(lp, expr_arena, lp_arena)) @@ -474,7 +475,7 @@ pub fn to_alp( } }, LogicalPlan::Sink { input, payload } => { - let input = to_alp(*input, expr_arena, lp_arena)?; + let input = to_alp(owned(input), expr_arena, lp_arena)?; ALogicalPlan::Sink { input, payload } }, }; @@ -770,7 +771,7 @@ impl ALogicalPlan { ALogicalPlan::Slice { input, offset, len } => { let lp = convert_to_lp(input, lp_arena); LogicalPlan::Slice { - input: Box::new(lp), + input: Arc::new(lp), offset, len, } @@ -779,7 +780,7 @@ impl ALogicalPlan { let lp = convert_to_lp(input, lp_arena); let predicate = predicate.to_expr(expr_arena); LogicalPlan::Selection { - input: Box::new(lp), + input: Arc::new(lp), predicate, } }, @@ -806,7 +807,7 @@ impl ALogicalPlan { let expr = expr_irs_to_exprs(expr.all_exprs(), expr_arena); LogicalPlan::Projection { expr, - input: Box::new(i), + input: Arc::new(i), schema, options, } @@ -819,7 +820,7 @@ impl ALogicalPlan { .collect::>(); LogicalPlan::Projection { expr, - input: Box::new(input), + input: Arc::new(input), schema: columns.clone(), options: Default::default(), } @@ -829,7 +830,7 @@ impl ALogicalPlan { by_column, args, } => { - let input = Box::new(convert_to_lp(input, lp_arena)); + let input = Arc::new(convert_to_lp(input, lp_arena)); let by_column = expr_irs_to_exprs(by_column, expr_arena); LogicalPlan::Sort { input, @@ -842,7 +843,7 @@ impl ALogicalPlan { id, cache_hits, } => { - let input = Box::new(convert_to_lp(input, lp_arena)); + let input = Arc::new(convert_to_lp(input, lp_arena)); LogicalPlan::Cache { input, id, @@ -863,7 +864,7 @@ impl ALogicalPlan { let aggs = expr_irs_to_exprs(aggs, expr_arena); LogicalPlan::Aggregate { - input: Box::new(i), + input: Arc::new(i), keys, aggs, schema, @@ -887,8 +888,8 @@ impl ALogicalPlan { let right_on = expr_irs_to_exprs(right_on, expr_arena); LogicalPlan::Join { - input_left: Box::new(i_l), - input_right: Box::new(i_r), + input_left: Arc::new(i_l), + input_right: Arc::new(i_r), schema, left_on, right_on, @@ -905,7 +906,7 @@ impl ALogicalPlan { let exprs = expr_irs_to_exprs(exprs.all_exprs(), expr_arena); LogicalPlan::HStack { 
- input: Box::new(i), + input: Arc::new(i), exprs, schema, options, @@ -914,12 +915,12 @@ impl ALogicalPlan { ALogicalPlan::Distinct { input, options } => { let i = convert_to_lp(input, lp_arena); LogicalPlan::Distinct { - input: Box::new(i), + input: Arc::new(i), options, } }, ALogicalPlan::MapFunction { input, function } => { - let input = Box::new(convert_to_lp(input, lp_arena)); + let input = Arc::new(convert_to_lp(input, lp_arena)); LogicalPlan::MapFunction { input, function } }, ALogicalPlan::ExtContext { @@ -927,7 +928,7 @@ impl ALogicalPlan { contexts, schema, } => { - let input = Box::new(convert_to_lp(input, lp_arena)); + let input = Arc::new(convert_to_lp(input, lp_arena)); let contexts = contexts .into_iter() .map(|node| convert_to_lp(node, lp_arena)) @@ -939,7 +940,7 @@ impl ALogicalPlan { } }, ALogicalPlan::Sink { input, payload } => { - let input = Box::new(convert_to_lp(input, lp_arena)); + let input = Arc::new(convert_to_lp(input, lp_arena)); LogicalPlan::Sink { input, payload } }, ALogicalPlan::Invalid => unreachable!(), diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 1b832b03ec57..4a6a910a9d07 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -135,12 +135,12 @@ pub enum LogicalPlan { PythonScan { options: PythonOptions }, /// Filter on a boolean mask Selection { - input: Box, + input: Arc, predicate: Expr, }, /// Cache the input at this point in the LP Cache { - input: Box, + input: Arc, id: usize, cache_hits: u32, }, @@ -164,13 +164,13 @@ pub enum LogicalPlan { /// Column selection Projection { expr: Vec, - input: Box, + input: Arc, schema: SchemaRef, options: ProjectionOptions, }, /// Groupby aggregation Aggregate { - input: Box, + input: Arc, keys: Arc>, aggs: Vec, schema: SchemaRef, @@ -181,8 +181,8 @@ pub enum LogicalPlan { }, /// Join operation Join { - input_left: Box, - input_right: Box, + input_left: Arc, + input_right: Arc, schema: SchemaRef, left_on: Vec, right_on: Vec, @@ -190,31 +190,31 @@ pub enum LogicalPlan { }, /// Adding columns to the table without a Join HStack { - input: Box, + input: Arc, exprs: Vec, schema: SchemaRef, options: ProjectionOptions, }, /// Remove duplicates from the table Distinct { - input: Box, + input: Arc, options: DistinctOptions, }, /// Sort the table Sort { - input: Box, + input: Arc, by_column: Vec, args: SortArguments, }, /// Slice the table Slice { - input: Box, + input: Arc, offset: i64, len: IdxSize, }, /// A (User Defined) Function MapFunction { - input: Box, + input: Arc, function: FunctionNode, }, Union { @@ -230,17 +230,17 @@ pub enum LogicalPlan { /// Catches errors and throws them later #[cfg_attr(feature = "serde", serde(skip))] Error { - input: Box, + input: Arc, err: ErrorState, }, /// This allows expressions to access other tables ExtContext { - input: Box, + input: Arc, contexts: Vec, schema: SchemaRef, }, Sink { - input: Box, + input: Arc, payload: SinkType, }, } From b0ece1e97aec58ab7ab7ce5a29e73c1dc54baa5c Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 1 Apr 2024 21:40:56 +0800 Subject: [PATCH 30/30] feat: Supports `explode_by_offsets` for decimal (#15417) --- .../src/series/implementations/decimal.rs | 11 +++++++ .../tests/unit/datatypes/test_decimal.py | 29 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 1c3e6f61566c..59163ab5cbd1 100644 --- 
a/crates/polars-core/src/series/implementations/decimal.rs
+++ b/crates/polars-core/src/series/implementations/decimal.rs
@@ -153,6 +153,17 @@ impl private::PrivateSeries for SeriesWrap<DecimalChunked> {
     fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
         self.0.group_tuples(multithreaded, sorted)
     }
+
+    fn explode_by_offsets(&self, offsets: &[i64]) -> Series {
+        self.0
+            .explode_by_offsets(offsets)
+            .decimal()
+            .unwrap()
+            .as_ref()
+            .clone()
+            .into_decimal_unchecked(self.0.precision(), self.0.scale())
+            .into_series()
+    }
 }
 
 impl SeriesTrait for SeriesWrap<DecimalChunked> {
diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py
index 14201fbe36b4..b81031129df4 100644
--- a/py-polars/tests/unit/datatypes/test_decimal.py
+++ b/py-polars/tests/unit/datatypes/test_decimal.py
@@ -409,3 +409,32 @@ def test_decimal_list_get_13847() -> None:
     out = df.select(pl.col("a").list.get(0))
     expected = pl.DataFrame({"a": [D("1.1"), D("2.1")]})
     assert_frame_equal(out, expected)
+
+
+def test_decimal_explode() -> None:
+    with pl.Config() as cfg:
+        cfg.activate_decimals()
+
+        nested_decimal_df = pl.DataFrame(
+            {
+                "bar": [[D("3.4"), D("3.4")], [D("4.5")]],
+            }
+        )
+        df = nested_decimal_df.explode("bar")
+        expected_df = pl.DataFrame(
+            {
+                "bar": [D("3.4"), D("3.4"), D("4.5")],
+            }
+        )
+        assert_frame_equal(df, expected_df)
+
+        # test group-by head #15330
+        df = pl.DataFrame(
+            {
+                "foo": [1, 1, 2],
+                "bar": [D("3.4"), D("3.4"), D("4.5")],
+            }
+        )
+        head_df = df.group_by("foo", maintain_order=True).head(1)
+        expected_df = pl.DataFrame({"foo": [1, 2], "bar": [D("3.4"), D("4.5")]})
+        assert_frame_equal(head_df, expected_df)
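+
+
+def test_decimal_explode_dtype_roundtrip() -> None:
+    # Hypothetical sketch, not part of the original patch: since
+    # `explode_by_offsets` rebuilds the result with the source precision
+    # and scale via `into_decimal_unchecked`, exploding a decimal list
+    # column should leave the inner dtype unchanged.
+    with pl.Config() as cfg:
+        cfg.activate_decimals()
+
+        df = pl.DataFrame({"bar": [[D("1.2")], [D("3.4"), D("5.6")]]})
+        out = df.explode("bar")
+        assert out.schema["bar"] == df.schema["bar"].inner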