From 665da25b16b15814237e69338acc47ba8781bc24 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 15 May 2024 18:57:43 +0100 Subject: [PATCH] feat: sort on behalf of user --- Cargo.lock | 1 + crates/polars-time/Cargo.toml | 1 + .../chunkedarray/rolling_window/dispatch.rs | 88 +++++--- .../src/chunkedarray/rolling_window/mod.rs | 3 - .../rolling_kernels/no_nulls.rs | 201 ++++++++++++++++-- py-polars/polars/expr/expr.py | 182 +++++++++------- py-polars/src/expr/rolling.rs | 32 +-- .../unit/operations/rolling/test_rolling.py | 62 ++++-- 8 files changed, 397 insertions(+), 173 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e02e9957631ce..3b40b0591619e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3150,6 +3150,7 @@ name = "polars-time" version = "0.39.2" dependencies = [ "atoi", + "bytemuck", "chrono", "chrono-tz", "now", diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index 14c0d02778e7d..6bf5c00e03619 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -23,6 +23,7 @@ once_cell = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } smartstring = { workspace = true } +bytemuck = {workspace = true} [dev-dependencies] polars-ops = { workspace = true, features = ["abs"] } diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index aefc54b83fda7..3ee24266444e6 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -73,6 +73,7 @@ fn rolling_agg_by( TimeUnit, Option<&TimeZone>, DynArgs, + Option<&[IdxSize]>, ) -> PolarsResult, ) -> PolarsResult where @@ -83,20 +84,9 @@ where } let ca = ca.rechunk(); let by = by.rechunk(); + polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column"); ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); - if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { - polars_warn!(format!( - "Series is not known to be sorted by `by` column in `rolling_*_by` operation.\n\ - \n\ - To silence this warning, you may want to try:\n\ - - sorting your data by your `by` column beforehand;\n\ - - setting `.set_sorted()` if you already know your data is sorted;\n\ - - passing `warn_if_unsorted=False` if this warning is a false-positive\n \ - (this is known to happen when combining rolling aggregations with `over`);\n\n\ - before passing calling the rolling aggregation function.\n", - )); - } let (by, tz) = match by.dtype() { DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), DataType::Date => ( @@ -109,32 +99,60 @@ where "date/datetime"), }; let by = by.datetime().unwrap(); - let by_values = by.cont_slice().map_err(|_| { - polars_err!( - ComputeError: - "`by` column should not have null values in 'rolling by' expression" - ) - })?; let tu = by.time_unit(); - let arr = ca.downcast_iter().next().unwrap(); - if arr.null_count() > 0 { - polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") - } - let values = arr.values().as_slice(); let func = rolling_agg_fn_dynamic; - - let arr = func( - values, - options.window_size, - by_values, - options.closed_window, - options.min_periods, - tu, - tz.as_ref(), - options.fn_params, - )?; - Series::try_from((ca.name(), arr)) + let out: ArrayRef = if matches!(by.is_sorted_flag(), IsSorted::Ascending) { + let arr = ca.downcast_iter().next().unwrap(); + if arr.null_count() > 0 { + polars_bail!(InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") + } + let by_values = by.cont_slice().map_err(|_| { + polars_err!( + ComputeError: + "`by` column should not have null values in 'rolling by' expression" + ) + })?; + let values = arr.values().as_slice(); + func( + values, + options.window_size, + by_values, + options.closed_window, + options.min_periods, + tu, + tz.as_ref(), + options.fn_params, + None, + )? + } else { + let sorting_indices = by.arg_sort(Default::default()); + let ca = unsafe { ca.take_unchecked(&sorting_indices) }; + let by = unsafe { by.take_unchecked(&sorting_indices) }; + let arr = ca.downcast_iter().next().unwrap(); + if arr.null_count() > 0 { + polars_bail!(InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") + } + let by_values = by.cont_slice().map_err(|_| { + polars_err!( + ComputeError: + "`by` column should not have null values in 'rolling by' expression" + ) + })?; + let values = arr.values().as_slice(); + func( + values, + options.window_size, + by_values, + options.closed_window, + options.min_periods, + tu, + tz.as_ref(), + options.fn_params, + Some(sorting_indices.cont_slice().unwrap()), + )? + }; + Series::try_from((ca.name(), out)) } pub trait SeriesOpsTime: AsSeries { diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index 0b2909b5dda47..7f1e95e2d46c4 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -23,8 +23,6 @@ pub struct RollingOptionsDynamicWindow { /// Optional parameters for the rolling function #[cfg_attr(feature = "serde", serde(skip))] pub fn_params: DynArgs, - /// Warn if data is not known to be sorted by `by` column (if passed) - pub warn_if_unsorted: bool, } #[cfg(feature = "rolling_window_by")] @@ -33,7 +31,6 @@ impl PartialEq for RollingOptionsDynamicWindow { self.window_size == other.window_size && self.min_periods == other.min_periods && self.closed_window == other.closed_window - && self.warn_if_unsorted == other.warn_if_unsorted && self.fn_params.is_none() && other.fn_params.is_none() } diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index abd4eadffc79e..0853e1a546890 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -1,11 +1,14 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::legacy::kernels::rolling::no_nulls::{self, RollingAggWindowNoNulls}; +use bytemuck::allocation::zeroed_vec; #[cfg(feature = "timezones")] use chrono_tz::Tz; use super::*; -// Use an aggregation window that maintains the state -pub(crate) fn rolling_apply_agg_window<'a, Agg, T, O>( +// Use an aggregation window that maintains the state. +// Fastpath if values were known to already be sorted by time. +pub(crate) fn rolling_apply_agg_window_sorted<'a, Agg, T, O>( values: &'a [T], offsets: O, min_periods: usize, @@ -50,6 +53,77 @@ where Ok(Box::new(out)) } +// Use an aggregation window that maintains the state +pub(crate) fn rolling_apply_agg_window<'a, Agg, T, O>( + values: &'a [T], + offsets: O, + min_periods: usize, + params: DynArgs, + sorting_indices: Option<&[IdxSize]>, +) -> PolarsResult +where + // items (offset, len) -> so offsets are offset, offset + len + Agg: RollingAggWindowNoNulls<'a, T>, + O: Iterator> + TrustedLen, + T: Debug + IsFloat + NativeType, +{ + if values.is_empty() { + let out: Vec = vec![]; + return Ok(Box::new(PrimitiveArray::new( + T::PRIMITIVE.into(), + out.into(), + None, + ))); + } + let sorting_indices = sorting_indices.expect("`sorting_indices` should have been set"); + // start with a dummy index, will be overwritten on first iteration. + let mut agg_window = Agg::new(values, 0, 0, params); + + let mut out = zeroed_vec(values.len()); + let mut null_positions = Vec::with_capacity(values.len()); + offsets.enumerate().try_for_each(|(idx, result)| { + let (start, len) = result?; + let end = start + len; + + // On the Python side, if `min_periods` wasn't specified, it is set to + // `1`. In that case, this condition is the same as checking + // `if start == end`. + if len >= (min_periods as IdxSize) { + // SAFETY: + // we are in bound + let res = unsafe { agg_window.update(start as usize, end as usize) }; + + if let Some(res) = res { + // SAFETY: `idx` is in bounds because `sorting_indices` was just taken from + // `by`, which has already been checked to be the same length as the values. + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = res; + } + } else { + null_positions.push(idx) + } + } else { + null_positions.push(idx) + } + Ok::<(), PolarsError>(()) + })?; + + let validity: Option = if null_positions.is_empty() { + None + } else { + let mut validity = MutableBitmap::with_capacity(values.len()); + validity.extend_constant(values.len(), true); + for idx in null_positions { + validity.set(idx, false) + } + Some(validity.into()) + }; + let out = PrimitiveArray::::from_vec(out).with_validity(validity); + + Ok(Box::new(out)) +} + #[allow(clippy::too_many_arguments)] pub(crate) fn rolling_min( values: &[T], @@ -60,6 +134,7 @@ pub(crate) fn rolling_min( tu: TimeUnit, tz: Option<&TimeZone>, _params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul, @@ -69,7 +144,22 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>(values, offset_iter, min_periods, None) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + None, + sorting_indices, + ) + } } #[allow(clippy::too_many_arguments)] @@ -82,6 +172,7 @@ pub(crate) fn rolling_max( tu: TimeUnit, tz: Option<&TimeZone>, _params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul, @@ -91,7 +182,22 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>(values, offset_iter, min_periods, None) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + None, + sorting_indices, + ) + } } #[allow(clippy::too_many_arguments)] @@ -104,6 +210,7 @@ pub(crate) fn rolling_sum( tu: TimeUnit, tz: Option<&TimeZone>, _params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + std::iter::Sum + NumCast + Mul + AddAssign + SubAssign + IsFloat, @@ -113,7 +220,22 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>(values, offset_iter, min_periods, None) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + None, + sorting_indices, + ) + } } #[allow(clippy::too_many_arguments)] @@ -126,6 +248,7 @@ pub(crate) fn rolling_mean( tu: TimeUnit, tz: Option<&TimeZone>, _params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + Float + std::iter::Sum + SubAssign + AddAssign + IsFloat, @@ -135,12 +258,22 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>( - values, - offset_iter, - min_periods, - None, - ) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + None, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + None, + sorting_indices, + ) + } } #[allow(clippy::too_many_arguments)] @@ -153,6 +286,7 @@ pub(crate) fn rolling_var( tu: TimeUnit, tz: Option<&TimeZone>, params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + Float + std::iter::Sum + SubAssign + AddAssign + IsFloat, @@ -162,12 +296,22 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>( - values, - offset_iter, - min_periods, - params, - ) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + params, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + params, + sorting_indices, + ) + } } #[allow(clippy::too_many_arguments)] @@ -180,6 +324,7 @@ pub(crate) fn rolling_quantile( tu: TimeUnit, tz: Option<&TimeZone>, params: DynArgs, + sorting_indices: Option<&[IdxSize]>, ) -> PolarsResult where T: NativeType + Float + std::iter::Sum + SubAssign + AddAssign + IsFloat, @@ -189,10 +334,20 @@ where Some(tz) => group_by_values_iter(period, time, closed_window, tu, tz.parse::().ok()), _ => group_by_values_iter(period, time, closed_window, tu, None), }?; - rolling_apply_agg_window::, _, _>( - values, - offset_iter, - min_periods, - params, - ) + if sorting_indices.is_none() { + rolling_apply_agg_window_sorted::, _, _>( + values, + offset_iter, + min_periods, + params, + ) + } else { + rolling_apply_agg_window::, _, _>( + values, + offset_iter, + min_periods, + params, + sorting_indices, + ) + } } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 61b413316acf7..24b39c1858b05 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6224,7 +6224,7 @@ def rolling_min_by( *, min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling min based on another column. @@ -6245,10 +6245,6 @@ def rolling_min_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6278,6 +6274,10 @@ def rolling_min_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -6340,22 +6340,26 @@ def rolling_min_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_min_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_min_by( - by, window_size, min_periods, closed, warn_if_unsorted - ) + self._pyexpr.rolling_min_by(by, window_size, min_periods, closed) ) @unstable() def rolling_max_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling max based on another column. @@ -6376,10 +6380,6 @@ def rolling_max_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6409,6 +6409,10 @@ def rolling_max_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -6497,22 +6501,26 @@ def rolling_max_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_max_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_max_by( - by, window_size, min_periods, closed, warn_if_unsorted - ) + self._pyexpr.rolling_max_by(by, window_size, min_periods, closed) ) @unstable() def rolling_mean_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling mean based on another column. @@ -6533,10 +6541,6 @@ def rolling_mean_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6566,6 +6570,10 @@ def rolling_mean_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -6656,6 +6664,12 @@ def rolling_mean_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_mean_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( self._pyexpr.rolling_mean_by( @@ -6663,19 +6677,18 @@ def rolling_mean_by( window_size, min_periods, closed, - warn_if_unsorted, ) ) @unstable() def rolling_sum_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling sum based on another column. @@ -6729,6 +6742,10 @@ def rolling_sum_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -6817,23 +6834,27 @@ def rolling_sum_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_sum_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_sum_by( - by, window_size, min_periods, closed, warn_if_unsorted - ) + self._pyexpr.rolling_sum_by(by, window_size, min_periods, closed) ) @unstable() def rolling_std_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", ddof: int = 1, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling standard deviation based on another column. @@ -6855,10 +6876,6 @@ def rolling_std_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -6890,6 +6907,10 @@ def rolling_std_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -6978,6 +6999,12 @@ def rolling_std_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_std_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( self._pyexpr.rolling_std_by( @@ -6986,20 +7013,19 @@ def rolling_std_by( min_periods, closed, ddof, - warn_if_unsorted, ) ) @unstable() def rolling_var_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", ddof: int = 1, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling variance based on another column. @@ -7020,10 +7046,6 @@ def rolling_var_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -7055,6 +7077,10 @@ def rolling_var_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -7143,6 +7169,12 @@ def rolling_var_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_var_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( self._pyexpr.rolling_var_by( @@ -7151,19 +7183,18 @@ def rolling_var_by( min_periods, closed, ddof, - warn_if_unsorted, ) ) @unstable() def rolling_median_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling median based on another column. @@ -7184,10 +7215,6 @@ def rolling_median_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. window_size The length of the window. Can be a dynamic temporal size indicated by a timedelta or the following string language: @@ -7217,6 +7244,10 @@ def rolling_median_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -7281,24 +7312,28 @@ def rolling_median_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_median_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_median_by( - by, window_size, min_periods, closed, warn_if_unsorted - ) + self._pyexpr.rolling_median_by(by, window_size, min_periods, closed) ) @unstable() def rolling_quantile_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, quantile: float, interpolation: RollingInterpolationMethod = "nearest", min_periods: int = 1, closed: ClosedInterval = "right", - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling quantile based on another column. @@ -7319,10 +7354,6 @@ def rolling_quantile_by( ---------- by This column must be of dtype Datetime or Date. - - .. warning:: - The column must be sorted in ascending order. Otherwise, - results will not be correct. quantile Quantile between 0.0 and 1.0. interpolation : {'nearest', 'higher', 'lower', 'midpoint', 'linear'} @@ -7356,6 +7387,10 @@ def rolling_quantile_by( warn_if_unsorted Warn if data is not known to be sorted by `by` column. + .. deprecated:: 0.20.27 + This operation no longer requires sorted data, you can safely remove + the `warn_if_unsorted` argument. + Notes ----- If you want to compute multiple aggregation statistics over the same dynamic @@ -7420,6 +7455,12 @@ def rolling_quantile_by( """ window_size = deprecate_saturating(window_size) window_size = _prepare_rolling_by_window_args(window_size) + if warn_if_unsorted is not None: + issue_deprecation_warning( + "`warn_if_unsorted` is deprecated in `rolling_quantile_by` because it " + "no longer requires sorted data - you can safely remove this argument.", + version="0.20.27", + ) by = parse_as_expression(by) return self._from_pyexpr( self._pyexpr.rolling_quantile_by( @@ -7429,7 +7470,6 @@ def rolling_quantile_by( window_size, min_periods, closed, - warn_if_unsorted, ) ) @@ -7443,7 +7483,7 @@ def rolling_min( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling min (moving min) over the values in this array. @@ -7657,7 +7697,6 @@ def rolling_min( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -7679,7 +7718,7 @@ def rolling_max( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling max (moving max) over the values in this array. @@ -7919,7 +7958,6 @@ def rolling_max( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -7941,7 +7979,7 @@ def rolling_mean( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling mean (moving mean) over the values in this array. @@ -8183,7 +8221,6 @@ def rolling_mean( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -8205,7 +8242,7 @@ def rolling_sum( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Apply a rolling sum (moving sum) over the values in this array. @@ -8445,7 +8482,6 @@ def rolling_sum( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -8468,7 +8504,7 @@ def rolling_std( by: str | None = None, closed: ClosedInterval | None = None, ddof: int = 1, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling standard deviation. @@ -8708,7 +8744,6 @@ def rolling_std( min_periods=min_periods, closed=closed or "right", ddof=ddof, - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -8732,7 +8767,7 @@ def rolling_var( by: str | None = None, closed: ClosedInterval | None = None, ddof: int = 1, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling variance. @@ -8971,7 +9006,6 @@ def rolling_var( min_periods=min_periods, closed=closed or "right", ddof=ddof, - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -8994,7 +9028,7 @@ def rolling_median( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling median. @@ -9153,7 +9187,6 @@ def rolling_median( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, ) window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( @@ -9177,7 +9210,7 @@ def rolling_quantile( center: bool = False, by: str | None = None, closed: ClosedInterval | None = None, - warn_if_unsorted: bool = True, + warn_if_unsorted: bool | None = None, ) -> Self: """ Compute a rolling quantile. @@ -9367,7 +9400,6 @@ def rolling_quantile( window_size=window_size, # type: ignore[arg-type] min_periods=min_periods, closed=closed or "right", - warn_if_unsorted=warn_if_unsorted, quantile=quantile, ) window_size = validate_rolling_aggs_arguments(window_size, closed) diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 44af77b2f4694..d5183c93e9f03 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -28,20 +28,18 @@ impl PyExpr { self.inner.clone().rolling_sum(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed))] fn rolling_sum_by( &self, by: PyExpr, window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, - warn_if_unsorted, fn_params: None, }; self.inner.clone().rolling_sum_by(by.inner, options).into() @@ -65,20 +63,18 @@ impl PyExpr { self.inner.clone().rolling_min(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed))] fn rolling_min_by( &self, by: PyExpr, window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, - warn_if_unsorted, fn_params: None, }; self.inner.clone().rolling_min_by(by.inner, options).into() @@ -101,20 +97,18 @@ impl PyExpr { }; self.inner.clone().rolling_max(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed))] fn rolling_max_by( &self, by: PyExpr, window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, - warn_if_unsorted, fn_params: None, }; self.inner.clone().rolling_max_by(by.inner, options).into() @@ -139,20 +133,18 @@ impl PyExpr { self.inner.clone().rolling_mean(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed))] fn rolling_mean_by( &self, by: PyExpr, window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, - warn_if_unsorted, fn_params: None, }; @@ -179,7 +171,7 @@ impl PyExpr { self.inner.clone().rolling_std(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed, ddof))] fn rolling_std_by( &self, by: PyExpr, @@ -187,14 +179,12 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - warn_if_unsorted, }; self.inner.clone().rolling_std_by(by.inner, options).into() @@ -220,7 +210,7 @@ impl PyExpr { self.inner.clone().rolling_var(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed, ddof))] fn rolling_var_by( &self, by: PyExpr, @@ -228,14 +218,12 @@ impl PyExpr { min_periods: usize, closed: Wrap, ddof: u8, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - warn_if_unsorted, }; self.inner.clone().rolling_var_by(by.inner, options).into() @@ -259,21 +247,19 @@ impl PyExpr { self.inner.clone().rolling_median(options).into() } - #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, window_size, min_periods, closed))] fn rolling_median_by( &self, by: PyExpr, window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, fn_params: None, - warn_if_unsorted, }; self.inner .clone() @@ -305,7 +291,7 @@ impl PyExpr { .into() } - #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed, warn_if_unsorted))] + #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed))] fn rolling_quantile_by( &self, by: PyExpr, @@ -314,14 +300,12 @@ impl PyExpr { window_size: &str, min_periods: usize, closed: Wrap, - warn_if_unsorted: bool, ) -> Self { let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), min_periods, closed_window: closed.0, fn_params: None, - warn_if_unsorted, }; self.inner diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 4bf0e46eda61b..77c5ee125ee8b 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -687,10 +687,23 @@ def test_rolling_aggregations_unsorted_raise_10991() -> None: "val": [1, 2, 3], } ) - with pytest.warns( - UserWarning, match="Series is not known to be sorted by `by` column." - ): - df.with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + result = df.with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + expected = pl.DataFrame( + { + "dt": [datetime(2020, 1, 3), datetime(2020, 1, 1), datetime(2020, 1, 2)], + "val": [1, 2, 3], + "roll": [4, 2, 5], + } + ) + assert_frame_equal(result, expected) + result = ( + df.with_row_index() + .sort("dt") + .with_columns(roll=pl.col("val").rolling_sum_by("dt", "2d")) + .sort("index") + .drop("index") + ) + assert_frame_equal(result, expected) def test_rolling_aggregations_with_over_11225() -> None: @@ -705,16 +718,17 @@ def test_rolling_aggregations_with_over_11225() -> None: df_temporal = df_temporal.sort("group", "date") - result = df_temporal.with_columns( - rolling_row_mean=pl.col("index") - .rolling_mean_by( - by="date", - window_size="2d", - closed="left", - warn_if_unsorted=False, + with pytest.deprecated_call(match="you can safely remove this argument"): + result = df_temporal.with_columns( + rolling_row_mean=pl.col("index") + .rolling_mean_by( + by="date", + window_size="2d", + closed="left", + warn_if_unsorted=False, + ) + .over("group") ) - .over("group") - ) expected = pl.DataFrame( { "index": [0, 1, 2, 3, 4], @@ -1021,6 +1035,14 @@ def test_temporal_windows_size_without_by_15977() -> None: df.select(pl.col("a").rolling_mean("3d")) +def test_by_different_length() -> None: + df = pl.DataFrame({"b": [1]}) + with pytest.raises(InvalidOperationError, match="must be the same length"): + df.select( + pl.col("b").rolling_max_by(pl.Series([datetime(2020, 1, 1)] * 2), "1d") + ) + + def test_incorrect_nulls_16246() -> None: df = pl.concat( [ @@ -1034,6 +1056,16 @@ def test_incorrect_nulls_16246() -> None: assert_frame_equal(result, expected) +def test_rolling_with_dst() -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 10, 26, 1), datetime(2020, 10, 26)], "b": [1, 2]} + ).with_columns(pl.col("a").dt.replace_time_zone("Europe/London")) + with pytest.raises(ComputeError, match="is ambiguous"): + df.select(pl.col("b").rolling_sum_by("a", "1d")) + with pytest.raises(ComputeError, match="is ambiguous"): + df.sort("a").select(pl.col("b").rolling_sum_by("a", "1d")) + + def interval_defs() -> SearchStrategy[ClosedInterval]: closed: list[ClosedInterval] = ["left", "right", "both", "none"] return st.sampled_from(closed) @@ -1160,6 +1192,9 @@ def test_rolling_aggs( result = df.with_columns( getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) ) + result_from_unsorted = dataframe.with_columns( + getattr(pl.col("value"), func)("ts", window_size=window_size, closed=closed) + ).sort("ts") expected_dict: dict[str, list[object]] = {"ts": [], "value": []} for ts, _ in df.iter_rows(): @@ -1183,3 +1218,4 @@ def test_rolling_aggs( pl.col("value").cast(result["value"].dtype), ) assert_frame_equal(result, expected) + assert_frame_equal(result_from_unsorted, expected)