diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 297caa7a71f3..6b87dbb177e4 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -88,6 +88,7 @@ take_opt_iter = [] group_by_list = [] # rolling window functions rolling_window = [] +rolling_window_by = [] diagonal_concat = [] dataframe_arithmetic = [] product = [] @@ -135,6 +136,7 @@ docs-selection = [ "dot_product", "row_hash", "rolling_window", + "rolling_window_by", "dtype-categorical", "dtype-decimal", "diagonal_concat", diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index 95679a4dafae..26ea0c4db61f 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -1,6 +1,9 @@ use arrow::legacy::prelude::DynArgs; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -#[derive(Clone)] +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingOptionsFixedWindow { /// The length of the window. pub window_size: usize, @@ -11,9 +14,22 @@ pub struct RollingOptionsFixedWindow { pub weights: Option>, /// Set the labels at the center of the window. pub center: bool, + #[cfg_attr(feature = "serde", serde(skip))] pub fn_params: DynArgs, } +#[cfg(feature = "rolling_window")] +impl PartialEq for RollingOptionsFixedWindow { + fn eq(&self, other: &Self) -> bool { + self.window_size == other.window_size + && self.min_periods == other.min_periods + && self.weights == other.weights + && self.center == other.center + && self.fn_params.is_none() + && other.fn_params.is_none() + } +} + impl Default for RollingOptionsFixedWindow { fn default() -> Self { RollingOptionsFixedWindow { diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 4c696453ed15..be6107c6b209 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -104,7 +104,10 @@ cum_agg = ["polars-plan/cum_agg"] interpolate = ["polars-plan/interpolate"] rolling_window = [ "polars-plan/rolling_window", - "polars-time/rolling_window", +] +rolling_window_by = [ + "polars-plan/rolling_window_by", + "polars-time/rolling_window_by", ] rank = ["polars-plan/rank"] diff = ["polars-plan/diff", "polars-plan/diff"] @@ -292,6 +295,7 @@ features = [ "replace", "rle", "rolling_window", + "rolling_window_by", "round_series", "row_hash", "search_sorted", diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index b986b5924d1b..0cdb926c886a 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -15,8 +15,8 @@ pub use polars_plan::logical_plan::{ }; pub use polars_plan::prelude::UnionArgs; pub(crate) use polars_plan::prelude::*; -#[cfg(feature = "rolling_window")] -pub use polars_time::{prelude::RollingOptions, Duration}; +#[cfg(feature = "rolling_window_by")] +pub use polars_time::Duration; #[cfg(feature = "dynamic_group_by")] pub use polars_time::{DynamicGroupOptions, PolarsTemporalGroupby, RollingGroupOptions}; pub(crate) use polars_utils::arena::{Arena, Node}; diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 85a1177b4a63..0e67cba50566 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -173,14 +173,14 @@ fn test_power_in_agg_list1() -> PolarsResult<()> { .group_by([col("fruits")]) .agg([ col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(1), + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, ..Default::default() }) .alias("input"), col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(1), + .rolling_min(RollingOptionsFixedWindow { + window_size: 1, ..Default::default() }) .pow(2.0) @@ -211,8 +211,8 @@ fn test_power_in_agg_list2() -> PolarsResult<()> { .lazy() .group_by([col("fruits")]) .agg([col("A") - .rolling_min(RollingOptions { - window_size: Duration::new(2), + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, min_periods: 2, ..Default::default() }) diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 168998e8c330..6f7d410dab98 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -105,6 +105,7 @@ log = [] hash = [] reinterpret = ["polars-core/reinterpret"] rolling_window = ["polars-core/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by"] moment = [] mode = [] search_sorted = [] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index ee6d0a2d43ee..bd3a1b9a3626 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -122,7 +122,11 @@ rolling_window = [ "polars-core/rolling_window", "polars-time/rolling_window", "polars-ops/rolling_window", - "polars-time/rolling_window", +] +rolling_window_by = [ + "polars-core/rolling_window_by", + "polars-time/rolling_window_by", + "polars-ops/rolling_window_by", ] rank = ["polars-ops/rank"] diff = ["polars-ops/diff"] @@ -180,6 +184,7 @@ features = [ "temporal", "serde", "rolling_window", + "rolling_window_by", "timezones", "dtype-date", "extract_groups", diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index e45cb5e86313..b9afd4c595d3 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -45,6 +45,8 @@ mod random; mod range; #[cfg(feature = "rolling_window")] pub mod rolling; +#[cfg(feature = "rolling_window_by")] +pub mod rolling_by; #[cfg(feature = "round_series")] mod round; #[cfg(feature = "row_hash")] @@ -96,6 +98,8 @@ pub use self::pow::PowFunction; pub(super) use self::range::RangeFunction; #[cfg(feature = "rolling_window")] pub(super) use self::rolling::RollingFunction; +#[cfg(feature = "rolling_window_by")] +pub(super) use self::rolling_by::RollingFunctionBy; #[cfg(feature = "strings")] pub(crate) use self::strings::StringFunction; #[cfg(feature = "dtype-struct")] @@ -156,6 +160,8 @@ pub enum FunctionExpr { FillNullWithStrategy(FillNullStrategy), #[cfg(feature = "rolling_window")] RollingExpr(RollingFunction), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(RollingFunctionBy), ShiftAndFill, Shift, DropNans, @@ -420,6 +426,10 @@ impl Hash for FunctionExpr { RollingExpr(f) => { f.hash(state); }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + f.hash(state); + }, #[cfg(feature = "moment")] Skew(a) => a.hash(state), #[cfg(feature = "moment")] @@ -609,6 +619,8 @@ impl Display for FunctionExpr { FillNull { .. } => "fill_null", #[cfg(feature = "rolling_window")] RollingExpr(func, ..) => return write!(f, "{func}"), + #[cfg(feature = "rolling_window_by")] + RollingExprBy(func, ..) => return write!(f, "{func}"), ShiftAndFill => "shift_and_fill", DropNans => "drop_nans", DropNulls => "drop_nulls", @@ -907,25 +919,31 @@ impl From for SpecialEq> { use RollingFunction::*; match f { Min(options) => map!(rolling::rolling_min, options.clone()), - MinBy(options) => map_as_slice!(rolling::rolling_min_by, options.clone()), Max(options) => map!(rolling::rolling_max, options.clone()), - MaxBy(options) => map_as_slice!(rolling::rolling_max_by, options.clone()), Mean(options) => map!(rolling::rolling_mean, options.clone()), - MeanBy(options) => map_as_slice!(rolling::rolling_mean_by, options.clone()), Sum(options) => map!(rolling::rolling_sum, options.clone()), - SumBy(options) => map_as_slice!(rolling::rolling_sum_by, options.clone()), Quantile(options) => map!(rolling::rolling_quantile, options.clone()), - QuantileBy(options) => { - map_as_slice!(rolling::rolling_quantile_by, options.clone()) - }, Var(options) => map!(rolling::rolling_var, options.clone()), - VarBy(options) => map_as_slice!(rolling::rolling_var_by, options.clone()), Std(options) => map!(rolling::rolling_std, options.clone()), - StdBy(options) => map_as_slice!(rolling::rolling_std_by, options.clone()), #[cfg(feature = "moment")] Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias), } }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(f) => { + use RollingFunctionBy::*; + match f { + MinBy(options) => map_as_slice!(rolling_by::rolling_min_by, options.clone()), + MaxBy(options) => map_as_slice!(rolling_by::rolling_max_by, options.clone()), + MeanBy(options) => map_as_slice!(rolling_by::rolling_mean_by, options.clone()), + SumBy(options) => map_as_slice!(rolling_by::rolling_sum_by, options.clone()), + QuantileBy(options) => { + map_as_slice!(rolling_by::rolling_quantile_by, options.clone()) + }, + VarBy(options) => map_as_slice!(rolling_by::rolling_var_by, options.clone()), + StdBy(options) => map_as_slice!(rolling_by::rolling_std_by, options.clone()), + } + }, #[cfg(feature = "hist")] Hist { bin_count, diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index f1ae64c5f792..9302ab4a1ad7 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -5,20 +5,13 @@ use super::*; #[derive(Clone, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum RollingFunction { - Min(RollingOptions), - MinBy(RollingOptions), - Max(RollingOptions), - MaxBy(RollingOptions), - Mean(RollingOptions), - MeanBy(RollingOptions), - Sum(RollingOptions), - SumBy(RollingOptions), - Quantile(RollingOptions), - QuantileBy(RollingOptions), - Var(RollingOptions), - VarBy(RollingOptions), - Std(RollingOptions), - StdBy(RollingOptions), + Min(RollingOptionsFixedWindow), + Max(RollingOptionsFixedWindow), + Mean(RollingOptionsFixedWindow), + Sum(RollingOptionsFixedWindow), + Quantile(RollingOptionsFixedWindow), + Var(RollingOptionsFixedWindow), + Std(RollingOptionsFixedWindow), #[cfg(feature = "moment")] Skew(usize, bool), } @@ -29,19 +22,12 @@ impl Display for RollingFunction { let name = match self { Min(_) => "rolling_min", - MinBy(_) => "rolling_min_by", Max(_) => "rolling_max", - MaxBy(_) => "rolling_max_by", Mean(_) => "rolling_mean", - MeanBy(_) => "rolling_mean_by", Sum(_) => "rolling_sum", - SumBy(_) => "rolling_sum_by", Quantile(_) => "rolling_quantile", - QuantileBy(_) => "rolling_quantile_by", Var(_) => "rolling_var", - VarBy(_) => "rolling_var_by", Std(_) => "rolling_std", - StdBy(_) => "rolling_std_by", #[cfg(feature = "moment")] Skew(..) => "rolling_skew", }; @@ -66,123 +52,35 @@ impl Hash for RollingFunction { } } -fn convert<'a>( - f: impl Fn(RollingOptionsImpl) -> PolarsResult + 'a, - ss: &'a [Series], - expr_name: &'static str, -) -> impl Fn(RollingOptions) -> PolarsResult + 'a { - move |options| { - let mut by = ss[1].clone(); - by = by.rechunk(); - - let (by, tz) = match by.dtype() { - DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), - DataType::Date => ( - by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, - &None, - ), - dt => polars_bail!(InvalidOperation: - "in `{}` operation, `by` argument of dtype `{}` is not supported (expected `{}`)", - expr_name, - dt, - "date/datetime"), - }; - if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { - polars_warn!(format!( - "Series is not known to be sorted by `by` column in {} operation.\n\ - \n\ - To silence this warning, you may want to try:\n\ - - sorting your data by your `by` column beforehand;\n\ - - setting `.set_sorted()` if you already know your data is sorted;\n\ - - passing `warn_if_unsorted=False` if this warning is a false-positive\n \ - (this is known to happen when combining rolling aggregations with `over`);\n\n\ - before passing calling the rolling aggregation function.\n", - expr_name - )); - } - let by = by.datetime().unwrap(); - let by_values = by.cont_slice().map_err(|_| { - polars_err!( - ComputeError: - "`by` column should not have null values in 'rolling by' expression" - ) - })?; - let tu = by.time_unit(); - - let options = RollingOptionsImpl { - window_size: options.window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - by: Some(by_values), - tu: Some(tu), - tz: tz.as_ref(), - closed_window: options.closed_window, - fn_params: options.fn_params.clone(), - }; - - f(options) - } -} - -pub(super) fn rolling_min(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_min(options.into()) -} - -pub(super) fn rolling_min_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_min(options), s, "rolling_min")(options) -} - -pub(super) fn rolling_max(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_max(options.into()) -} - -pub(super) fn rolling_max_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_max(options), s, "rolling_max")(options) -} - -pub(super) fn rolling_mean(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_mean(options.into()) -} - -pub(super) fn rolling_mean_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_mean(options), s, "rolling_mean")(options) -} - -pub(super) fn rolling_sum(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_sum(options.into()) -} - -pub(super) fn rolling_sum_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_sum(options), s, "rolling_sum")(options) +pub(super) fn rolling_min(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_min(options) } -pub(super) fn rolling_quantile(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_quantile(options.into()) +pub(super) fn rolling_max(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_max(options) } -pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert( - |options| s[0].rolling_quantile(options), - s, - "rolling_quantile", - )(options) +pub(super) fn rolling_mean(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_mean(options) } -pub(super) fn rolling_var(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_var(options.into()) +pub(super) fn rolling_sum(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_sum(options) } -pub(super) fn rolling_var_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_var(options), s, "rolling_var")(options) +pub(super) fn rolling_quantile( + s: &Series, + options: RollingOptionsFixedWindow, +) -> PolarsResult { + s.rolling_quantile(options) } -pub(super) fn rolling_std(s: &Series, options: RollingOptions) -> PolarsResult { - s.rolling_std(options.into()) +pub(super) fn rolling_var(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_var(options) } -pub(super) fn rolling_std_by(s: &[Series], options: RollingOptions) -> PolarsResult { - convert(|options| s[0].rolling_std(options), s, "rolling_std")(options) +pub(super) fn rolling_std(s: &Series, options: RollingOptionsFixedWindow) -> PolarsResult { + s.rolling_std(options) } #[cfg(feature = "moment")] diff --git a/crates/polars-plan/src/dsl/function_expr/rolling_by.rs b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs new file mode 100644 index 000000000000..c2b3510281f2 --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/rolling_by.rs @@ -0,0 +1,88 @@ +use polars_time::chunkedarray::*; + +use super::*; + +#[derive(Clone, PartialEq, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum RollingFunctionBy { + MinBy(RollingOptionsDynamicWindow), + MaxBy(RollingOptionsDynamicWindow), + MeanBy(RollingOptionsDynamicWindow), + SumBy(RollingOptionsDynamicWindow), + QuantileBy(RollingOptionsDynamicWindow), + VarBy(RollingOptionsDynamicWindow), + StdBy(RollingOptionsDynamicWindow), +} + +impl Display for RollingFunctionBy { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RollingFunctionBy::*; + + let name = match self { + MinBy(_) => "rolling_min_by", + MaxBy(_) => "rolling_max_by", + MeanBy(_) => "rolling_mean_by", + SumBy(_) => "rolling_sum_by", + QuantileBy(_) => "rolling_quantile_by", + VarBy(_) => "rolling_var_by", + StdBy(_) => "rolling_std_by", + }; + + write!(f, "{name}") + } +} + +impl Hash for RollingFunctionBy { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + } +} + +pub(super) fn rolling_min_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_min_by(&s[1], options) +} + +pub(super) fn rolling_max_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_max_by(&s[1], options) +} + +pub(super) fn rolling_mean_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_mean_by(&s[1], options) +} + +pub(super) fn rolling_sum_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_sum_by(&s[1], options) +} + +pub(super) fn rolling_quantile_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_quantile_by(&s[1], options) +} + +pub(super) fn rolling_var_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_var_by(&s[1], options) +} + +pub(super) fn rolling_std_by( + s: &[Series], + options: RollingOptionsDynamicWindow, +) -> PolarsResult { + s[0].rolling_std_by(&s[1], options) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 830891fea1cb..e301557a247a 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -64,15 +64,20 @@ impl FunctionExpr { RollingExpr(rolling_func, ..) => { use RollingFunction::*; match rolling_func { - Min(_) | MinBy(_) | Max(_) | MaxBy(_) | Sum(_) | SumBy(_) => { - mapper.with_same_dtype() - }, - Mean(_) | MeanBy(_) | Quantile(_) | QuantileBy(_) | Var(_) | VarBy(_) - | Std(_) | StdBy(_) => mapper.map_to_float_dtype(), + Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(), + Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), #[cfg(feature = "moment")] Skew(..) => mapper.map_to_float_dtype(), } }, + #[cfg(feature = "rolling_window_by")] + RollingExprBy(rolling_func, ..) => { + use RollingFunctionBy::*; + match rolling_func { + MinBy(_) | MaxBy(_) | SumBy(_) => mapper.with_same_dtype(), + MeanBy(_) | QuantileBy(_) | VarBy(_) | StdBy(_) => mapper.map_to_float_dtype(), + } + }, ShiftAndFill => mapper.with_same_dtype(), DropNans => mapper.with_same_dtype(), DropNulls => mapper.with_same_dtype(), diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index a41a8c8621a2..651365091cbe 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -73,8 +73,8 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E #[cfg(feature = "rolling_window")] pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 - let rolling_options = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: options.min_periods as usize, ..Default::default() }; @@ -85,8 +85,8 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { let var_x = x.clone().rolling_var(rolling_options.clone()); let var_y = y.clone().rolling_var(rolling_options); - let rolling_options_count = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options_count = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: 0, ..Default::default() }; @@ -104,8 +104,8 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { #[cfg(feature = "rolling_window")] pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700 - let rolling_options = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: options.min_periods as usize, ..Default::default() }; @@ -113,8 +113,8 @@ pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); let mean_x = x.clone().rolling_mean(rolling_options.clone()); let mean_y = y.clone().rolling_mean(rolling_options); - let rolling_options_count = RollingOptions { - window_size: Duration::new(options.window_size as i64), + let rolling_options_count = RollingOptionsFixedWindow { + window_size: options.window_size as usize, min_periods: 0, ..Default::default() }; diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 015e999851df..3c4b96130583 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -3,12 +3,12 @@ #[cfg(feature = "dtype-categorical")] pub mod cat; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] use std::any::Any; #[cfg(feature = "dtype-categorical")] pub use cat::*; -#[cfg(feature = "rolling_window")] +#[cfg(feature = "rolling_window_by")] pub(crate) use polars_time::prelude::*; mod arithmetic; @@ -1237,64 +1237,128 @@ impl Expr { self.apply_private(FunctionExpr::Interpolate(method)) } + #[cfg(feature = "rolling_window_by")] + #[allow(clippy::type_complexity)] + fn finish_rolling_by( + self, + by: Expr, + options: RollingOptionsDynamicWindow, + rolling_function_by: fn(RollingOptionsDynamicWindow) -> RollingFunctionBy, + ) -> Expr { + self.apply_many_private( + FunctionExpr::RollingExprBy(rolling_function_by(options)), + &[by], + false, + false, + ) + } + #[cfg(feature = "rolling_window")] #[allow(clippy::type_complexity)] fn finish_rolling( self, - options: RollingOptions, - rolling_function: fn(RollingOptions) -> RollingFunction, - rolling_function_by: fn(RollingOptions) -> RollingFunction, + options: RollingOptionsFixedWindow, + rolling_function: fn(RollingOptionsFixedWindow) -> RollingFunction, ) -> Expr { - if let Some(ref by) = options.by { - let name = by.clone(); - self.apply_many_private( - FunctionExpr::RollingExpr(rolling_function_by(options)), - &[col(&name)], - false, - false, - ) - } else { - self.apply_private(FunctionExpr::RollingExpr(rolling_function(options))) - } + self.apply_private(FunctionExpr::RollingExpr(rolling_function(options))) + } + + /// Apply a rolling minimum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_min_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MinBy) + } + + /// Apply a rolling maximum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_max_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MaxBy) + } + + /// Apply a rolling mean based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_mean_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::MeanBy) + } + + /// Apply a rolling sum based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_sum_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::SumBy) + } + + /// Apply a rolling quantile based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_quantile_by( + self, + by: Expr, + interpol: QuantileInterpolOptions, + quantile: f64, + mut options: RollingOptionsDynamicWindow, + ) -> Expr { + options.fn_params = Some(Arc::new(RollingQuantileParams { + prob: quantile, + interpol, + }) as Arc); + + self.finish_rolling_by(by, options, RollingFunctionBy::QuantileBy) + } + + /// Apply a rolling variance based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_var_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::VarBy) + } + + /// Apply a rolling std-dev based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_std_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.finish_rolling_by(by, options, RollingFunctionBy::StdBy) + } + + /// Apply a rolling median based on another column. + #[cfg(feature = "rolling_window_by")] + pub fn rolling_median_by(self, by: Expr, options: RollingOptionsDynamicWindow) -> Expr { + self.rolling_quantile_by(by, QuantileInterpolOptions::Linear, 0.5, options) } /// Apply a rolling minimum. /// /// See: [`RollingAgg::rolling_min`] #[cfg(feature = "rolling_window")] - pub fn rolling_min(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Min, RollingFunction::MinBy) + pub fn rolling_min(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Min) } /// Apply a rolling maximum. /// /// See: [`RollingAgg::rolling_max`] #[cfg(feature = "rolling_window")] - pub fn rolling_max(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Max, RollingFunction::MaxBy) + pub fn rolling_max(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Max) } /// Apply a rolling mean. /// /// See: [`RollingAgg::rolling_mean`] #[cfg(feature = "rolling_window")] - pub fn rolling_mean(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Mean, RollingFunction::MeanBy) + pub fn rolling_mean(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Mean) } /// Apply a rolling sum. /// /// See: [`RollingAgg::rolling_sum`] #[cfg(feature = "rolling_window")] - pub fn rolling_sum(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Sum, RollingFunction::SumBy) + pub fn rolling_sum(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Sum) } /// Apply a rolling median. /// /// See: [`RollingAgg::rolling_median`] #[cfg(feature = "rolling_window")] - pub fn rolling_median(self, options: RollingOptions) -> Expr { + pub fn rolling_median(self, options: RollingOptionsFixedWindow) -> Expr { self.rolling_quantile(QuantileInterpolOptions::Linear, 0.5, options) } @@ -1306,30 +1370,26 @@ impl Expr { self, interpol: QuantileInterpolOptions, quantile: f64, - mut options: RollingOptions, + mut options: RollingOptionsFixedWindow, ) -> Expr { options.fn_params = Some(Arc::new(RollingQuantileParams { prob: quantile, interpol, }) as Arc); - self.finish_rolling( - options, - RollingFunction::Quantile, - RollingFunction::QuantileBy, - ) + self.finish_rolling(options, RollingFunction::Quantile) } /// Apply a rolling variance. #[cfg(feature = "rolling_window")] - pub fn rolling_var(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Var, RollingFunction::VarBy) + pub fn rolling_var(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Var) } /// Apply a rolling std-dev. #[cfg(feature = "rolling_window")] - pub fn rolling_std(self, options: RollingOptions) -> Expr { - self.finish_rolling(options, RollingFunction::Std, RollingFunction::StdBy) + pub fn rolling_std(self, options: RollingOptionsFixedWindow) -> Expr { + self.finish_rolling(options, RollingFunction::Std) } /// Apply a rolling skew. diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index 2925de13c869..9e0773ecd2c6 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -32,7 +32,8 @@ dtype-date = ["polars-core/dtype-date", "temporal"] dtype-datetime = ["polars-core/dtype-datetime", "temporal"] dtype-time = ["polars-core/dtype-time", "temporal"] dtype-duration = ["polars-core/dtype-duration", "temporal"] -rolling_window = ["polars-core/rolling_window", "dtype-duration"] +rolling_window = ["polars-core/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by", "dtype-duration"] fmt = ["polars-core/fmt"] serde = ["dep:serde", "smartstring/serde"] temporal = ["polars-core/temporal"] diff --git a/crates/polars-time/src/chunkedarray/mod.rs b/crates/polars-time/src/chunkedarray/mod.rs index 4c2fb9cbf505..e61031d46ed1 100644 --- a/crates/polars-time/src/chunkedarray/mod.rs +++ b/crates/polars-time/src/chunkedarray/mod.rs @@ -6,7 +6,7 @@ mod datetime; #[cfg(feature = "dtype-duration")] mod duration; mod kernels; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] mod rolling_window; pub mod string; #[cfg(feature = "dtype-time")] @@ -22,7 +22,7 @@ pub use datetime::DatetimeMethods; pub use duration::DurationMethods; use kernels::*; use polars_core::prelude::*; -#[cfg(feature = "rolling_window")] +#[cfg(any(feature = "rolling_window", feature = "rolling_window_by"))] pub use rolling_window::*; pub use string::StringMethods; #[cfg(feature = "dtype-time")] diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 1e6eb024919d..5feb3f9f99cb 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -1,13 +1,15 @@ +use polars_core::series::IsSorted; use polars_core::{with_match_physical_float_polars_type, with_match_physical_numeric_polars_type}; use super::*; use crate::prelude::*; use crate::series::AsSeries; +#[cfg(feature = "rolling_window")] #[allow(clippy::type_complexity)] fn rolling_agg( ca: &ChunkedArray, - options: RollingOptionsImpl, + options: RollingOptionsFixedWindow, rolling_agg_fn: &dyn Fn( &[T::Native], usize, @@ -24,79 +26,140 @@ fn rolling_agg( Option<&[f64]>, DynArgs, ) -> ArrayRef, - rolling_agg_fn_dynamic: Option< - &dyn Fn( - &[T::Native], - Duration, - &[i64], - ClosedWindow, - usize, - TimeUnit, - Option<&TimeZone>, - DynArgs, - ) -> PolarsResult, - >, ) -> PolarsResult where T: PolarsNumericType, { + polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`"); if ca.is_empty() { return Ok(Series::new_empty(ca.name(), ca.dtype())); } let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); - // "5i" is a window size of 5, e.g. fixed - let arr = if options.by.is_none() { - let options: RollingOptionsFixedWindow = options.try_into()?; - Ok(match ca.null_count() { - 0 => rolling_agg_fn( - arr.values().as_slice(), - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - )?, - _ => rolling_agg_fn_nulls( - arr, - options.window_size, - options.min_periods, - options.center, - options.weights.as_deref(), - options.fn_params, - ), - }) - } else { - let options: RollingOptionsDynamicWindow = options.try_into()?; - if arr.null_count() > 0 { - polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") - } - let values = arr.values().as_slice(); - let tu = options.tu.expect("time_unit was set in `convert` function"); - let by = options.by; - let func = rolling_agg_fn_dynamic.expect("rolling_agg_fn_dynamic must have been passed"); - - func( - values, + let arr = match ca.null_count() { + 0 => rolling_agg_fn( + arr.values().as_slice(), + options.window_size, + options.min_periods, + options.center, + options.weights.as_deref(), + options.fn_params, + )?, + _ => rolling_agg_fn_nulls( + arr, options.window_size, - by, - options.closed_window, options.min_periods, - tu, - options.tz, + options.center, + options.weights.as_deref(), options.fn_params, + ), + }; + Series::try_from((ca.name(), arr)) +} + +#[cfg(feature = "rolling_window_by")] +#[allow(clippy::type_complexity)] +fn rolling_agg_by( + ca: &ChunkedArray, + by: &Series, + options: RollingOptionsDynamicWindow, + rolling_agg_fn_dynamic: &dyn Fn( + &[T::Native], + Duration, + &[i64], + ClosedWindow, + usize, + TimeUnit, + Option<&TimeZone>, + DynArgs, + ) -> PolarsResult, +) -> PolarsResult +where + T: PolarsNumericType, +{ + if ca.is_empty() { + return Ok(Series::new_empty(ca.name(), ca.dtype())); + } + let ca = ca.rechunk(); + ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; + polars_ensure!(options.window_size.duration_ns()>0 && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); + if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted { + polars_warn!(format!( + "Series is not known to be sorted by `by` column in `rolling_*_by` operation.\n\ + \n\ + To silence this warning, you may want to try:\n\ + - sorting your data by your `by` column beforehand;\n\ + - setting `.set_sorted()` if you already know your data is sorted;\n\ + - passing `warn_if_unsorted=False` if this warning is a false-positive\n \ + (this is known to happen when combining rolling aggregations with `over`);\n\n\ + before passing calling the rolling aggregation function.\n", + )); + } + let (by, tz) = match by.dtype() { + DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), + DataType::Date => ( + by.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, + &None, + ), + dt => polars_bail!(InvalidOperation: + "in `rolling_*_by` operation, `by` argument of dtype `{}` is not supported (expected `{}`)", + dt, + "date/datetime"), + }; + let by = by.datetime().unwrap(); + let by_values = by.cont_slice().map_err(|_| { + polars_err!( + ComputeError: + "`by` column should not have null values in 'rolling by' expression" ) - }?; + })?; + let tu = by.time_unit(); + + let arr = ca.downcast_iter().next().unwrap(); + if arr.null_count() > 0 { + polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'") + } + let values = arr.values().as_slice(); + let func = rolling_agg_fn_dynamic; + + let arr = func( + values, + options.window_size, + by_values, + options.closed_window, + options.min_periods, + tu, + tz.as_ref(), + options.fn_params, + )?; Series::try_from((ca.name(), arr)) } pub trait SeriesOpsTime: AsSeries { + /// Apply a rolling mean to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_mean_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_mean, + ) + }) + } /// Apply a rolling mean to a Series. /// /// See: [`RollingAgg::rolling_mean`] #[cfg(feature = "rolling_window")] - fn rolling_mean(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_mean(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); @@ -105,13 +168,31 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_mean, &rolling::nulls::rolling_mean, - Some(&super::rolling_kernels::no_nulls::rolling_mean), ) }) } + /// Apply a rolling sum to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_sum_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_sum, + ) + }) + } + /// Apply a rolling sum to a Series. #[cfg(feature = "rolling_window")] - fn rolling_sum(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_sum(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -124,14 +205,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_sum, &rolling::nulls::rolling_sum, - Some(&super::rolling_kernels::no_nulls::rolling_sum), ) }) } + /// Apply a rolling quantile to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_quantile_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_quantile, + ) + }) + } + /// Apply a rolling quantile to a Series. #[cfg(feature = "rolling_window")] - fn rolling_quantile(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_quantile(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); @@ -140,14 +239,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_quantile, &rolling::nulls::rolling_quantile, - Some(&super::rolling_kernels::no_nulls::rolling_quantile), ) }) } + /// Apply a rolling min to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_min_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_min, + ) + }) + } + /// Apply a rolling min to a Series. #[cfg(feature = "rolling_window")] - fn rolling_min(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_min(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -160,13 +277,32 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_min, &rolling::nulls::rolling_min, - Some(&super::rolling_kernels::no_nulls::rolling_min), ) }) } + + /// Apply a rolling max to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_max_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().clone(); + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + rolling_agg_by( + ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_max, + ) + }) + } + /// Apply a rolling max to a Series. #[cfg(feature = "rolling_window")] - fn rolling_max(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_max(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let mut s = self.as_series().clone(); if options.weights.is_some() { s = s.to_float()?; @@ -179,14 +315,48 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_max, &rolling::nulls::rolling_max, - Some(&super::rolling_kernels::no_nulls::rolling_max), + ) + }) + } + + /// Apply a rolling variance to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_var_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + let s = self.as_series().to_float()?; + + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let mut ca = ca.clone(); + + if let Some(idx) = ca.first_non_null() { + let k = ca.get(idx).unwrap(); + // TODO! remove this! + // This is a temporary hack to improve numeric stability. + // var(X) = var(X - k) + // This is temporary as we will rework the rolling methods + // the 100.0 absolute boundary is arbitrarily chosen. + // the algorithm will square numbers, so it loses precision rapidly + if k.abs() > 100.0 { + ca = ca - k; + } + } + + rolling_agg_by( + &ca, + by, + options, + &super::rolling_kernels::no_nulls::rolling_var, ) }) } /// Apply a rolling variance to a Series. #[cfg(feature = "rolling_window")] - fn rolling_var(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_var(&self, options: RollingOptionsFixedWindow) -> PolarsResult { let s = self.as_series().to_float()?; with_match_physical_float_polars_type!(s.dtype(), |$T| { @@ -211,14 +381,36 @@ pub trait SeriesOpsTime: AsSeries { options, &rolling::no_nulls::rolling_var, &rolling::nulls::rolling_var, - Some(&super::rolling_kernels::no_nulls::rolling_var), ) }) } + /// Apply a rolling std_dev to a Series based on another Series. + #[cfg(feature = "rolling_window_by")] + fn rolling_std_by( + &self, + by: &Series, + options: RollingOptionsDynamicWindow, + ) -> PolarsResult { + self.rolling_var_by(by, options).map(|mut s| { + match s.dtype().clone() { + DataType::Float32 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + DataType::Float64 => { + let ca: &mut ChunkedArray = s._get_inner_mut().as_mut(); + ca.apply_mut(|v| v.powf(0.5)) + }, + _ => unreachable!(), + } + s + }) + } + /// Apply a rolling std_dev to a Series. #[cfg(feature = "rolling_window")] - fn rolling_std(&self, options: RollingOptionsImpl) -> PolarsResult { + fn rolling_std(&self, options: RollingOptionsFixedWindow) -> PolarsResult { self.rolling_var(options).map(|mut s| { match s.dtype().clone() { DataType::Float32 => { diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index d5ae53e1459f..0b2909b5dda4 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -1,4 +1,5 @@ mod dispatch; +#[cfg(feature = "rolling_window_by")] mod rolling_kernels; use arrow::array::{Array, ArrayRef, PrimitiveArray}; @@ -12,20 +13,13 @@ use crate::prelude::*; #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct RollingOptions { +pub struct RollingOptionsDynamicWindow { /// The length of the window. pub window_size: Duration, /// Amount of elements in the window that should be filled before computing a result. pub min_periods: usize, - /// An optional slice with the same length as the window that will be multiplied - /// elementwise with the values in the window. - pub weights: Option>, - /// Set the labels at the center of the window. - pub center: bool, - /// Compute the rolling aggregates with a window defined by a time column - pub by: Option, - /// The closed window of that time window if given - pub closed_window: Option, + /// Which side windows should be closed. + pub closed_window: ClosedWindow, /// Optional parameters for the rolling function #[cfg_attr(feature = "serde", serde(skip))] pub fn_params: DynArgs, @@ -33,152 +27,14 @@ pub struct RollingOptions { pub warn_if_unsorted: bool, } -impl Default for RollingOptions { - fn default() -> Self { - RollingOptions { - window_size: Duration::parse("3i"), - min_periods: 1, - weights: None, - center: false, - by: None, - closed_window: None, - fn_params: None, - warn_if_unsorted: true, - } - } -} - -#[cfg(feature = "rolling_window")] -impl PartialEq for RollingOptions { +#[cfg(feature = "rolling_window_by")] +impl PartialEq for RollingOptionsDynamicWindow { fn eq(&self, other: &Self) -> bool { self.window_size == other.window_size && self.min_periods == other.min_periods - && self.weights == other.weights - && self.center == other.center - && self.by == other.by && self.closed_window == other.closed_window + && self.warn_if_unsorted == other.warn_if_unsorted && self.fn_params.is_none() && other.fn_params.is_none() } } - -#[derive(Clone)] -pub struct RollingOptionsImpl<'a> { - /// The length of the window. - pub window_size: Duration, - /// Amount of elements in the window that should be filled before computing a result. - pub min_periods: usize, - /// An optional slice with the same length as the window that will be multiplied - /// elementwise with the values in the window. - pub weights: Option>, - /// Set the labels at the center of the window. - pub center: bool, - pub by: Option<&'a [i64]>, - pub tu: Option, - pub tz: Option<&'a TimeZone>, - pub closed_window: Option, - pub fn_params: DynArgs, -} - -impl From for RollingOptionsImpl<'static> { - fn from(options: RollingOptions) -> Self { - RollingOptionsImpl { - window_size: options.window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - by: None, - tu: None, - tz: None, - closed_window: options.closed_window, - fn_params: options.fn_params, - } - } -} - -impl Default for RollingOptionsImpl<'static> { - fn default() -> Self { - RollingOptionsImpl { - window_size: Duration::parse("3i"), - min_periods: 1, - weights: None, - center: false, - by: None, - tu: None, - tz: None, - closed_window: None, - fn_params: None, - } - } -} - -impl<'a> TryFrom> for RollingOptionsFixedWindow { - type Error = PolarsError; - fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { - polars_ensure!( - options.window_size.parsed_int, - InvalidOperation: "if `window_size` is a temporal window (e.g. '1d', '2h, ...), then the `by` argument must be passed" - ); - polars_ensure!( - options.closed_window.is_none(), - InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \ - consider using DataFrame.rolling for greater flexibility", - ); - let window_size = options.window_size.nanoseconds() as usize; - check_input(window_size, options.min_periods)?; - Ok(RollingOptionsFixedWindow { - window_size, - min_periods: options.min_periods, - weights: options.weights, - center: options.center, - fn_params: options.fn_params, - }) - } -} - -/// utility -fn check_input(window_size: usize, min_periods: usize) -> PolarsResult<()> { - polars_ensure!( - min_periods <= window_size, - ComputeError: "`min_periods` should be <= `window_size`", - ); - Ok(()) -} - -#[derive(Clone)] -pub struct RollingOptionsDynamicWindow<'a> { - /// The length of the window. - pub window_size: Duration, - /// Amount of elements in the window that should be filled before computing a result. - pub min_periods: usize, - pub by: &'a [i64], - pub tu: Option, - pub tz: Option<&'a TimeZone>, - pub closed_window: ClosedWindow, - pub fn_params: DynArgs, -} - -impl<'a> TryFrom> for RollingOptionsDynamicWindow<'a> { - type Error = PolarsError; - fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult { - let duration = options.window_size; - polars_ensure!(duration.duration_ns() > 0 && !duration.negative, ComputeError:"window size should be strictly positive"); - polars_ensure!( - options.weights.is_none(), - InvalidOperation: "`weights` is not supported in 'rolling_*(..., by=...)' expression" - ); - polars_ensure!( - !options.window_size.parsed_int, - InvalidOperation: "if `by` argument is passed, then `window_size` must be a temporal window (e.g. '1d' or '2h', not '3i')" - ); - Ok(RollingOptionsDynamicWindow { - window_size: options.window_size, - min_periods: options.min_periods, - by: options.by.expect("by must have been set to get here"), - tu: options.tu, - tz: options.tz, - closed_window: options.closed_window.unwrap_or(ClosedWindow::Right), - fn_params: options.fn_params, - }) - } -} diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index c7cb2429fa22..7b48db38c8e6 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -443,7 +443,7 @@ pub(crate) fn group_by_values_iter_lookahead( }) } -#[cfg(feature = "rolling_window")] +#[cfg(feature = "rolling_window_by")] #[inline] pub(crate) fn group_by_values_iter( period: Duration, diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 9056f42abfaa..0c16fa597cfa 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -195,7 +195,8 @@ reinterpret = ["polars-core/reinterpret", "polars-lazy?/reinterpret", "polars-op repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"] replace = ["polars-ops/replace", "polars-lazy?/replace"] rle = ["polars-lazy?/rle"] -rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] +rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window"] +rolling_window_by = ["polars-core/rolling_window_by", "polars-lazy?/rolling_window_by", "polars-time/rolling_window_by"] round_series = ["polars-ops/round_series", "polars-lazy?/round_series"] row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] search_sorted = ["polars-lazy?/search_sorted"] @@ -366,6 +367,7 @@ docs-selection = [ "take_opt_iter", "cum_agg", "rolling_window", + "rolling_window_by", "interpolate", "diff", "rank", diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs index 17270932bca2..a374523e1454 100644 --- a/crates/polars/tests/it/core/rolling_window.rs +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -4,8 +4,8 @@ use super::*; fn test_rolling() { let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); let a = s - .rolling_sum(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_sum(RollingOptionsFixedWindow { + window_size: 2, min_periods: 1, ..Default::default() }) @@ -20,8 +20,8 @@ fn test_rolling() { .collect::>() ); let a = s - .rolling_min(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_min(RollingOptionsFixedWindow { + window_size: 2, min_periods: 1, ..Default::default() }) @@ -36,8 +36,8 @@ fn test_rolling() { .collect::>() ); let a = s - .rolling_max(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_max(RollingOptionsFixedWindow { + window_size: 2, weights: Some(vec![1., 1.]), min_periods: 1, ..Default::default() @@ -59,8 +59,8 @@ fn test_rolling() { fn test_rolling_min_periods() { let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); let a = s - .rolling_max(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_max(RollingOptionsFixedWindow { + window_size: 2, min_periods: 2, ..Default::default() }) @@ -87,8 +87,8 @@ fn test_rolling_mean() { // check err on wrong input assert!(s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(1), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 1, min_periods: 2, ..Default::default() }) @@ -96,8 +96,8 @@ fn test_rolling_mean() { // validate that we divide by the proper window length. (same as pandas) let a = s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(3), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, min_periods: 1, center: false, ..Default::default() @@ -119,8 +119,8 @@ fn test_rolling_mean() { // check centered rolling window let a = s - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(3), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 3, min_periods: 1, center: true, ..Default::default() @@ -144,8 +144,8 @@ fn test_rolling_mean() { let ca = Int32Chunked::from_slice("", &[1, 8, 6, 2, 16, 10]); let out = ca .into_series() - .rolling_mean(RollingOptionsImpl { - window_size: Duration::new(2), + .rolling_mean(RollingOptionsFixedWindow { + window_size: 2, weights: None, min_periods: 2, center: false, @@ -211,8 +211,8 @@ fn test_rolling_var() { .into_series(); // window larger than array assert_eq!( - s.rolling_var(RollingOptionsImpl { - window_size: Duration::new(10), + s.rolling_var(RollingOptionsFixedWindow { + window_size: 10, min_periods: 10, ..Default::default() }) @@ -221,8 +221,8 @@ fn test_rolling_var() { s.len() ); - let options = RollingOptionsImpl { - window_size: Duration::new(3), + let options = RollingOptionsFixedWindow { + window_size: 3, min_periods: 3, ..Default::default() }; @@ -252,8 +252,8 @@ fn test_rolling_var() { // check centered rolling window let out = s - .rolling_var(RollingOptionsImpl { - window_size: Duration::new(4), + .rolling_var(RollingOptionsFixedWindow { + window_size: 4, min_periods: 3, center: true, ..Default::default() diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index cb530b2d1445..d4452c23f913 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -77,6 +77,7 @@ features = [ "reinterpret", "replace", "rolling_window", + "rolling_window_by", "round_series", "row_hash", "rows", diff --git a/py-polars/polars/_utils/deprecation.py b/py-polars/polars/_utils/deprecation.py index b74c1a3a7c07..9c3382f4982d 100644 --- a/py-polars/polars/_utils/deprecation.py +++ b/py-polars/polars/_utils/deprecation.py @@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Callable, Sequence, TypeVar from polars._utils.various import find_stacklevel +from polars.exceptions import InvalidOperationError if TYPE_CHECKING: import sys from typing import Mapping from polars import Expr - from polars.type_aliases import Ambiguous + from polars.type_aliases import Ambiguous, ClosedInterval if sys.version_info >= (3, 10): from typing import ParamSpec @@ -275,3 +276,36 @@ def deprecate_saturating(duration: T) -> T: ) return duration[:-11] # type: ignore[return-value] return duration + + +def validate_rolling_by_aggs_arguments( + weights: list[float] | None, *, center: bool +) -> None: + if weights is not None: + msg = "`weights` is not supported in `rolling_*(..., by=...)` expression" + raise InvalidOperationError(msg) + if center: + msg = "`center=True` is not supported in `rolling_*(..., by=...)` expression" + raise InvalidOperationError(msg) + + +def validate_rolling_aggs_arguments( + window_size: int | str, closed: ClosedInterval | None +) -> int: + if isinstance(window_size, str): + issue_deprecation_warning( + "Passing a str to `rolling_*` is deprecated.\n\n" + "Please, either:\n" + "- pass an integer if you want a fixed window size (e.g. `rolling_mean(3)`)\n" + "- pass a string if you are computing the rolling operation based on another column (e.g. `rolling_mean_by('date', '3d'))\n", + version="0.20.26", + ) + try: + window_size = int(window_size.rstrip("i")) + except ValueError: + msg = f"Expected a string of the form 'ni', where `n` is a positive integer, got: {window_size}" + raise InvalidOperationError(msg) from None + if closed is not None: + msg = "`closed` is not supported in `rolling_*(...)` expression" + raise InvalidOperationError(msg) + return window_size diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f662b348396a..799ba329509e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -33,6 +33,8 @@ deprecate_renamed_parameter, deprecate_saturating, issue_deprecation_warning, + validate_rolling_aggs_arguments, + validate_rolling_by_aggs_arguments, ) from polars._utils.parse_expr_input import ( parse_as_expression, @@ -6165,7 +6167,7 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self: @unstable() def rolling_min_by( self, - by: str, + by: IntoExpr, window_size: timedelta | str, *, min_periods: int = 1, @@ -6285,12 +6287,11 @@ def rolling_min_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_min( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_min_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6443,12 +6444,11 @@ def rolling_max_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_max( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_max_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6603,16 +6603,13 @@ def rolling_mean_by( └───────┴─────────────────────┴──────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_mean( + self._pyexpr.rolling_mean_by( + by, window_size, - None, min_periods, - False, - by, closed, warn_if_unsorted, ) @@ -6767,12 +6764,11 @@ def rolling_sum_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_sum( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_sum_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -6929,16 +6925,13 @@ def rolling_std_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_std( + self._pyexpr.rolling_std_by( + by, window_size, - None, min_periods, - False, - by, closed, ddof, warn_if_unsorted, @@ -7097,16 +7090,13 @@ def rolling_var_by( └───────┴─────────────────────┴─────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_var( + self._pyexpr.rolling_var_by( + by, window_size, - None, min_periods, - False, - by, closed, ddof, warn_if_unsorted, @@ -7238,12 +7228,11 @@ def rolling_median_by( └───────┴─────────────────────┴────────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_median( - window_size, None, min_periods, False, by, closed, warn_if_unsorted + self._pyexpr.rolling_median_by( + by, window_size, min_periods, closed, warn_if_unsorted ) ) @@ -7378,18 +7367,15 @@ def rolling_quantile_by( └───────┴─────────────────────┴──────────────────────┘ """ window_size = deprecate_saturating(window_size) - window_size, min_periods = _prepare_rolling_window_args( - window_size, min_periods - ) + window_size = _prepare_rolling_by_window_args(window_size) + by = parse_as_expression(by) return self._from_pyexpr( - self._pyexpr.rolling_quantile( + self._pyexpr.rolling_quantile_by( + by, quantile, interpolation, window_size, - None, min_periods, - False, - by, closed, warn_if_unsorted, ) @@ -7612,9 +7598,22 @@ def rolling_min( "`rolling_min(..., by='foo')`, please use `rolling_min_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_min_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_min( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -7861,9 +7860,22 @@ def rolling_max( "`rolling_max(..., by='foo')`, please use `rolling_max_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_max_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_max( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -8112,15 +8124,22 @@ def rolling_mean( "`rolling_mean(..., by='foo')`, please use `rolling_mean_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_mean_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_mean( window_size, weights, min_periods, center, - by, - closed, - warn_if_unsorted, ) ) @@ -8367,9 +8386,22 @@ def rolling_sum( "`rolling_sum(..., by='foo')`, please use `rolling_sum_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_sum_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_sum( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -8616,16 +8648,24 @@ def rolling_std( "`rolling_std(..., by='foo')`, please use `rolling_std_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_std_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + ddof=ddof, + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_std( window_size, weights, min_periods, center, - by, - closed, ddof, - warn_if_unsorted, ) ) @@ -8871,16 +8911,24 @@ def rolling_var( "`rolling_var(..., by='foo')`, please use `rolling_var_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_var_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + ddof=ddof, + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_var( window_size, weights, min_periods, center, - by, - closed, ddof, - warn_if_unsorted, ) ) @@ -9046,9 +9094,22 @@ def rolling_median( "`rolling_median(..., by='foo')`, please use `rolling_median_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_median_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_median( - window_size, weights, min_periods, center, by, closed, warn_if_unsorted + window_size, + weights, + min_periods, + center, ) ) @@ -9247,6 +9308,17 @@ def rolling_quantile( "`rolling_quantile(..., by='foo')`, please use `rolling_quantile_by('foo', ...)`.", version="0.20.24", ) + validate_rolling_by_aggs_arguments(weights, center=center) + return self.rolling_quantile_by( + by=by, + # integer `window_size` was already not supported when `by` was passed + window_size=window_size, # type: ignore[arg-type] + min_periods=min_periods, + closed=closed or "right", + warn_if_unsorted=warn_if_unsorted, + quantile=quantile, + ) + window_size = validate_rolling_aggs_arguments(window_size, closed) return self._from_pyexpr( self._pyexpr.rolling_quantile( quantile, @@ -9255,9 +9327,6 @@ def rolling_quantile( weights, min_periods, center, - by, - closed, - warn_if_unsorted, ) ) @@ -11940,7 +12009,7 @@ def _prepare_alpha( def _prepare_rolling_window_args( window_size: int | timedelta | str, min_periods: int | None = None, -) -> tuple[str, int]: +) -> tuple[int | str, int]: if isinstance(window_size, int): if window_size < 1: msg = "`window_size` must be positive" @@ -11948,9 +12017,16 @@ def _prepare_rolling_window_args( if min_periods is None: min_periods = window_size - window_size = f"{window_size}i" elif isinstance(window_size, timedelta): window_size = parse_as_duration_string(window_size) if min_periods is None: min_periods = 1 return window_size, min_periods + + +def _prepare_rolling_by_window_args( + window_size: timedelta | str, +) -> str: + if isinstance(window_size, timedelta): + window_size = parse_as_duration_string(window_size) + return window_size diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 5c0c24e3a7a5..44af77b2f469 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -10,201 +10,293 @@ use crate::{PyExpr, PySeries}; #[pymethods] impl PyExpr { - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (window_size, weights, min_periods, center))] fn rolling_sum( &self, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_sum(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_min( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_sum_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_sum_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_min( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_min(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_max( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_min_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_min_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_max( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_max(options).into() } + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_max_by( + &self, + by: PyExpr, + window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + self.inner.clone().rolling_max_by(by.inner, options).into() + } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (window_size, weights, min_periods, center))] fn rolling_mean( &self, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), - warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_mean(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] - fn rolling_std( + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_mean_by( &self, + by: PyExpr, window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + warn_if_unsorted, + fn_params: None, + }; + + self.inner.clone().rolling_mean_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] + fn rolling_std( + &self, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, ddof: u8, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - warn_if_unsorted, }; self.inner.clone().rolling_std(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] - fn rolling_var( + #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + fn rolling_std_by( &self, + by: PyExpr, window_size: &str, - weights: Option>, min_periods: usize, - center: bool, - by: Option, - closed: Option>, + closed: Wrap, ddof: u8, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, + }; + + self.inner.clone().rolling_std_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center, ddof))] + fn rolling_var( + &self, + window_size: usize, + weights: Option>, + min_periods: usize, + center: bool, + ddof: u8, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), - warn_if_unsorted, }; self.inner.clone().rolling_var(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - fn rolling_median( + #[pyo3(signature = (by, window_size, min_periods, closed, ddof, warn_if_unsorted))] + fn rolling_var_by( &self, + by: PyExpr, window_size: &str, + min_periods: usize, + closed: Wrap, + ddof: u8, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, + }; + + self.inner.clone().rolling_var_by(by.inner, options).into() + } + + #[pyo3(signature = (window_size, weights, min_periods, center))] + fn rolling_median( + &self, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, + ) -> Self { + let options = RollingOptionsFixedWindow { + window_size, + min_periods, + weights, + center, + fn_params: None, + }; + self.inner.clone().rolling_median(options).into() + } + + #[pyo3(signature = (by, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_median_by( + &self, + by: PyExpr, + window_size: &str, + min_periods: usize, + closed: Wrap, warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { + let options = RollingOptionsDynamicWindow { window_size: Duration::parse(window_size), - weights, min_periods, - center, - by, - closed_window: closed.map(|c| c.0), + closed_window: closed.0, fn_params: None, warn_if_unsorted, }; - self.inner.clone().rolling_median(options).into() + self.inner + .clone() + .rolling_median_by(by.inner, options) + .into() } - #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center))] fn rolling_quantile( &self, quantile: f64, interpolation: Wrap, - window_size: &str, + window_size: usize, weights: Option>, min_periods: usize, center: bool, - by: Option, - closed: Option>, - warn_if_unsorted: bool, ) -> Self { - let options = RollingOptions { - window_size: Duration::parse(window_size), + let options = RollingOptionsFixedWindow { + window_size, weights, min_periods, center, - by, - closed_window: closed.map(|c| c.0), fn_params: None, - warn_if_unsorted, }; self.inner @@ -213,6 +305,31 @@ impl PyExpr { .into() } + #[pyo3(signature = (by, quantile, interpolation, window_size, min_periods, closed, warn_if_unsorted))] + fn rolling_quantile_by( + &self, + by: PyExpr, + quantile: f64, + interpolation: Wrap, + window_size: &str, + min_periods: usize, + closed: Wrap, + warn_if_unsorted: bool, + ) -> Self { + let options = RollingOptionsDynamicWindow { + window_size: Duration::parse(window_size), + min_periods, + closed_window: closed.0, + fn_params: None, + warn_if_unsorted, + }; + + self.inner + .clone() + .rolling_quantile_by(by.inner, interpolation.0, quantile, options) + .into() + } + fn rolling_skew(&self, window_size: usize, bias: bool) -> Self { self.inner.clone().rolling_skew(window_size, bias).into() } diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 5fd75d05bbc0..c5db4007dd16 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -1,5 +1,6 @@ use polars_core::series::IsSorted; use polars_plan::dsl::function_expr::rolling::RollingFunction; +use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy; use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction; use polars_plan::dsl::BooleanFunction; use polars_plan::prelude::{ @@ -628,49 +629,51 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { RollingFunction::Min(_) => { return Err(PyNotImplementedError::new_err("rolling min")) }, - RollingFunction::MinBy(_) => { - return Err(PyNotImplementedError::new_err("rolling min by")) - }, RollingFunction::Max(_) => { return Err(PyNotImplementedError::new_err("rolling max")) }, - RollingFunction::MaxBy(_) => { - return Err(PyNotImplementedError::new_err("rolling max by")) - }, RollingFunction::Mean(_) => { return Err(PyNotImplementedError::new_err("rolling mean")) }, - RollingFunction::MeanBy(_) => { - return Err(PyNotImplementedError::new_err("rolling mean by")) - }, RollingFunction::Sum(_) => { return Err(PyNotImplementedError::new_err("rolling sum")) }, - RollingFunction::SumBy(_) => { - return Err(PyNotImplementedError::new_err("rolling sum by")) - }, RollingFunction::Quantile(_) => { return Err(PyNotImplementedError::new_err("rolling quantile")) }, - RollingFunction::QuantileBy(_) => { - return Err(PyNotImplementedError::new_err("rolling quantile by")) - }, RollingFunction::Var(_) => { return Err(PyNotImplementedError::new_err("rolling var")) }, - RollingFunction::VarBy(_) => { - return Err(PyNotImplementedError::new_err("rolling var by")) - }, RollingFunction::Std(_) => { return Err(PyNotImplementedError::new_err("rolling std")) }, - RollingFunction::StdBy(_) => { - return Err(PyNotImplementedError::new_err("rolling std by")) - }, RollingFunction::Skew(_, _) => { return Err(PyNotImplementedError::new_err("rolling skew")) }, }, + FunctionExpr::RollingExprBy(rolling) => match rolling { + RollingFunctionBy::MinBy(_) => { + return Err(PyNotImplementedError::new_err("rolling min by")) + }, + RollingFunctionBy::MaxBy(_) => { + return Err(PyNotImplementedError::new_err("rolling max by")) + }, + RollingFunctionBy::MeanBy(_) => { + return Err(PyNotImplementedError::new_err("rolling mean by")) + }, + RollingFunctionBy::SumBy(_) => { + return Err(PyNotImplementedError::new_err("rolling sum by")) + }, + RollingFunctionBy::QuantileBy(_) => { + return Err(PyNotImplementedError::new_err("rolling quantile by")) + }, + RollingFunctionBy::VarBy(_) => { + return Err(PyNotImplementedError::new_err("rolling var by")) + }, + RollingFunctionBy::StdBy(_) => { + return Err(PyNotImplementedError::new_err("rolling std by")) + }, + }, FunctionExpr::ShiftAndFill => { return Err(PyNotImplementedError::new_err("shift and fill")) }, diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index bc69f1d5ca0b..8898c8a29d31 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -52,6 +52,9 @@ def test_rolling_kernels_and_rolling( pl.col("values").rolling_var_by("dt", period, closed=closed).alias("var"), pl.col("values").rolling_mean_by("dt", period, closed=closed).alias("mean"), pl.col("values").rolling_std_by("dt", period, closed=closed).alias("std"), + pl.col("values") + .rolling_quantile_by("dt", period, quantile=0.2, closed=closed) + .alias("quantile"), ] ) out2 = ( @@ -63,6 +66,7 @@ def test_rolling_kernels_and_rolling( pl.col("values").var().alias("var"), pl.col("values").mean().alias("mean"), pl.col("values").std().alias("std"), + pl.col("values").quantile(quantile=0.2).alias("quantile"), ] ) ) @@ -220,13 +224,13 @@ def test_rolling_crossing_dst( def test_rolling_by_invalid() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") - msg = "in `rolling_min` operation, `by` argument of dtype `i64` is not supported" + msg = r"in `rolling_\*_by` operation, `by` argument of dtype `i64` is not supported" with pytest.raises(InvalidOperationError, match=msg): - df.select(pl.col("b").rolling_min_by("a", 2)) # type: ignore[arg-type] + df.select(pl.col("b").rolling_min_by("a", "2i")) df = pl.DataFrame({"a": [1, 2, 3], "b": [date(2020, 1, 1)] * 3}).sort("b") - msg = "if `by` argument is passed, then `window_size` must be a temporal window" + msg = "`window_size` duration may not be a parsed integer" with pytest.raises(InvalidOperationError, match=msg): - df.select(pl.col("a").rolling_min_by("b", 2)) # type: ignore[arg-type] + df.select(pl.col("a").rolling_min_by("b", "2i")) def test_rolling_infinity() -> None: @@ -240,7 +244,10 @@ def test_rolling_invalid_closed_option() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("a", "b") - with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"): + with pytest.raises( + InvalidOperationError, + match=r"`closed` is not supported in `rolling_\*\(...\)` expression", + ): df.with_columns(pl.col("a").rolling_sum(2, closed="left")) @@ -248,21 +255,31 @@ def test_rolling_by_non_temporal_window_size() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("a", "b") - msg = "if `by` argument is passed, then `window_size` must be a temporal window" + msg = "`window_size` duration may not be a parsed integer" with pytest.raises(InvalidOperationError, match=msg): - df.with_columns(pl.col("a").rolling_sum_by("b", 2, closed="left")) # type: ignore[arg-type] + df.with_columns(pl.col("a").rolling_sum_by("b", "2i", closed="left")) def test_rolling_by_weights() -> None: df = pl.DataFrame( {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ).sort("b") - msg = r"`weights` is not supported in 'rolling_\*\(..., by=...\)' expression" + msg = r"`weights` is not supported in `rolling_\*\(..., by=...\)` expression" with pytest.raises(InvalidOperationError, match=msg): # noqa: SIM117 with pytest.deprecated_call(match="rolling_sum_by"): df.with_columns(pl.col("a").rolling_sum("2d", by="b", weights=[1, 2])) +def test_rolling_by_center() -> None: + df = pl.DataFrame( + {"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} + ).sort("b") + msg = r"`center=True` is not supported in `rolling_\*\(..., by=...\)` expression" + with pytest.raises(InvalidOperationError, match=msg): # noqa: SIM117 + with pytest.deprecated_call(match="rolling_sum_by"): + df.with_columns(pl.col("a").rolling_sum("2d", by="b", center=True)) + + def test_rolling_extrema() -> None: # sorted data and nulls flags trigger different kernels df = ( @@ -566,11 +583,15 @@ def test_rolling_negative_period() -> None: df.lazy().rolling("ts", period="-1d", offset="-1d").agg( pl.col("value") ).collect() - with pytest.raises(ComputeError, match="window size should be strictly positive"): + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): df.select( pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") ) - with pytest.raises(ComputeError, match="window size should be strictly positive"): + with pytest.raises( + InvalidOperationError, match="`window_size` must be strictly positive" + ): df.lazy().select( pl.col("value").rolling_min_by("ts", window_size="-1d", closed="left") ).collect() @@ -984,7 +1005,10 @@ def test_temporal_windows_size_without_by_15977() -> None: df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ) - with pytest.raises( - pl.InvalidOperationError, match="the `by` argument must be passed" + with pytest.raises( # noqa: SIM117 + InvalidOperationError, match="Expected a string of the form 'ni'" ): - df.select(pl.col("a").rolling_mean("3d")) + with pytest.deprecated_call( + match=r"Passing a str to `rolling_\*` is deprecated" + ): + df.select(pl.col("a").rolling_mean("3d"))