Skip to content

Commit

Permalink
feat(rust!): split out rolling_*_by from rolling_* aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 10, 2024
1 parent d341156 commit 01d74de
Show file tree
Hide file tree
Showing 26 changed files with 1,021 additions and 618 deletions.
2 changes: 2 additions & 0 deletions crates/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ take_opt_iter = []
group_by_list = []
# rolling window functions
rolling_window = []
rolling_window_by = []
diagonal_concat = []
dataframe_arithmetic = []
product = []
Expand Down Expand Up @@ -135,6 +136,7 @@ docs-selection = [
"dot_product",
"row_hash",
"rolling_window",
"rolling_window_by",
"dtype-categorical",
"dtype-decimal",
"diagonal_concat",
Expand Down
18 changes: 17 additions & 1 deletion crates/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use arrow::legacy::prelude::DynArgs;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[derive(Clone)]
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingOptionsFixedWindow {
/// The length of the window.
pub window_size: usize,
Expand All @@ -11,9 +14,22 @@ pub struct RollingOptionsFixedWindow {
pub weights: Option<Vec<f64>>,
/// Set the labels at the center of the window.
pub center: bool,
#[cfg_attr(feature = "serde", serde(skip))]
pub fn_params: DynArgs,
}

#[cfg(feature = "rolling_window")]
impl PartialEq for RollingOptionsFixedWindow {
    /// Compares two fixed-window option sets field by field.
    ///
    /// `window_size`, `min_periods`, `weights` and `center` must all match.
    /// `fn_params` is only checked for presence: when either side has it set
    /// the comparison yields `false`, so a value carrying `fn_params` does not
    /// compare equal even to itself.
    fn eq(&self, other: &Self) -> bool {
        // Only the presence of the dynamic parameters is checked, never their contents.
        let no_dyn_params = self.fn_params.is_none() && other.fn_params.is_none();

        let same_window = self.window_size == other.window_size
            && self.min_periods == other.min_periods
            && self.center == other.center;

        same_window && self.weights == other.weights && no_dyn_params
    }
}

impl Default for RollingOptionsFixedWindow {
fn default() -> Self {
RollingOptionsFixedWindow {
Expand Down
6 changes: 5 additions & 1 deletion crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ cum_agg = ["polars-plan/cum_agg"]
interpolate = ["polars-plan/interpolate"]
rolling_window = [
"polars-plan/rolling_window",
"polars-time/rolling_window",
]
rolling_window_by = [
"polars-plan/rolling_window_by",
"polars-time/rolling_window_by",
]
rank = ["polars-plan/rank"]
diff = ["polars-plan/diff", "polars-plan/diff"]
Expand Down Expand Up @@ -292,6 +295,7 @@ features = [
"replace",
"rle",
"rolling_window",
"rolling_window_by",
"round_series",
"row_hash",
"search_sorted",
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ pub use polars_plan::logical_plan::{
};
pub use polars_plan::prelude::UnionArgs;
pub(crate) use polars_plan::prelude::*;
#[cfg(feature = "rolling_window")]
pub use polars_time::{prelude::RollingOptions, Duration};
#[cfg(feature = "rolling_window_by")]
pub use polars_time::Duration;
#[cfg(feature = "dynamic_group_by")]
pub use polars_time::{DynamicGroupOptions, PolarsTemporalGroupby, RollingGroupOptions};
pub(crate) use polars_utils::arena::{Arena, Node};
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ fn test_power_in_agg_list1() -> PolarsResult<()> {
.group_by([col("fruits")])
.agg([
col("A")
.rolling_min(RollingOptions {
window_size: Duration::new(1),
.rolling_min(RollingOptionsFixedWindow {
window_size: 1,
..Default::default()
})
.alias("input"),
col("A")
.rolling_min(RollingOptions {
window_size: Duration::new(1),
.rolling_min(RollingOptionsFixedWindow {
window_size: 1,
..Default::default()
})
.pow(2.0)
Expand Down Expand Up @@ -211,8 +211,8 @@ fn test_power_in_agg_list2() -> PolarsResult<()> {
.lazy()
.group_by([col("fruits")])
.agg([col("A")
.rolling_min(RollingOptions {
window_size: Duration::new(2),
.rolling_min(RollingOptionsFixedWindow {
window_size: 2,
min_periods: 2,
..Default::default()
})
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ log = []
hash = []
reinterpret = ["polars-core/reinterpret"]
rolling_window = ["polars-core/rolling_window"]
rolling_window_by = ["polars-core/rolling_window_by"]
moment = []
mode = []
search_sorted = []
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ rolling_window = [
"polars-core/rolling_window",
"polars-time/rolling_window",
"polars-ops/rolling_window",
"polars-time/rolling_window",
]
rolling_window_by = [
"polars-core/rolling_window_by",
"polars-time/rolling_window_by",
"polars-ops/rolling_window_by",
]
rank = ["polars-ops/rank"]
diff = ["polars-ops/diff"]
Expand Down Expand Up @@ -180,6 +184,7 @@ features = [
"temporal",
"serde",
"rolling_window",
"rolling_window_by",
"timezones",
"dtype-date",
"extract_groups",
Expand Down
36 changes: 27 additions & 9 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ mod random;
mod range;
#[cfg(feature = "rolling_window")]
pub mod rolling;
#[cfg(feature = "rolling_window_by")]
pub mod rolling_by;
#[cfg(feature = "round_series")]
mod round;
#[cfg(feature = "row_hash")]
Expand Down Expand Up @@ -96,6 +98,8 @@ pub use self::pow::PowFunction;
pub(super) use self::range::RangeFunction;
#[cfg(feature = "rolling_window")]
pub(super) use self::rolling::RollingFunction;
#[cfg(feature = "rolling_window_by")]
pub(super) use self::rolling_by::RollingFunctionBy;
#[cfg(feature = "strings")]
pub(crate) use self::strings::StringFunction;
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -156,6 +160,8 @@ pub enum FunctionExpr {
FillNullWithStrategy(FillNullStrategy),
#[cfg(feature = "rolling_window")]
RollingExpr(RollingFunction),
#[cfg(feature = "rolling_window_by")]
RollingExprBy(RollingFunctionBy),
ShiftAndFill,
Shift,
DropNans,
Expand Down Expand Up @@ -420,6 +426,10 @@ impl Hash for FunctionExpr {
RollingExpr(f) => {
f.hash(state);
},
#[cfg(feature = "rolling_window_by")]
RollingExprBy(f) => {
f.hash(state);
},
#[cfg(feature = "moment")]
Skew(a) => a.hash(state),
#[cfg(feature = "moment")]
Expand Down Expand Up @@ -609,6 +619,8 @@ impl Display for FunctionExpr {
FillNull { .. } => "fill_null",
#[cfg(feature = "rolling_window")]
RollingExpr(func, ..) => return write!(f, "{func}"),
#[cfg(feature = "rolling_window_by")]
RollingExprBy(func, ..) => return write!(f, "{func}"),
ShiftAndFill => "shift_and_fill",
DropNans => "drop_nans",
DropNulls => "drop_nulls",
Expand Down Expand Up @@ -907,25 +919,31 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
use RollingFunction::*;
match f {
Min(options) => map!(rolling::rolling_min, options.clone()),
MinBy(options) => map_as_slice!(rolling::rolling_min_by, options.clone()),
Max(options) => map!(rolling::rolling_max, options.clone()),
MaxBy(options) => map_as_slice!(rolling::rolling_max_by, options.clone()),
Mean(options) => map!(rolling::rolling_mean, options.clone()),
MeanBy(options) => map_as_slice!(rolling::rolling_mean_by, options.clone()),
Sum(options) => map!(rolling::rolling_sum, options.clone()),
SumBy(options) => map_as_slice!(rolling::rolling_sum_by, options.clone()),
Quantile(options) => map!(rolling::rolling_quantile, options.clone()),
QuantileBy(options) => {
map_as_slice!(rolling::rolling_quantile_by, options.clone())
},
Var(options) => map!(rolling::rolling_var, options.clone()),
VarBy(options) => map_as_slice!(rolling::rolling_var_by, options.clone()),
Std(options) => map!(rolling::rolling_std, options.clone()),
StdBy(options) => map_as_slice!(rolling::rolling_std_by, options.clone()),
#[cfg(feature = "moment")]
Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias),
}
},
#[cfg(feature = "rolling_window_by")]
RollingExprBy(f) => {
use RollingFunctionBy::*;
match f {
MinBy(options) => map_as_slice!(rolling_by::rolling_min_by, options.clone()),
MaxBy(options) => map_as_slice!(rolling_by::rolling_max_by, options.clone()),
MeanBy(options) => map_as_slice!(rolling_by::rolling_mean_by, options.clone()),
SumBy(options) => map_as_slice!(rolling_by::rolling_sum_by, options.clone()),
QuantileBy(options) => {
map_as_slice!(rolling_by::rolling_quantile_by, options.clone())
},
VarBy(options) => map_as_slice!(rolling_by::rolling_var_by, options.clone()),
StdBy(options) => map_as_slice!(rolling_by::rolling_std_by, options.clone()),
}
},
#[cfg(feature = "hist")]
Hist {
bin_count,
Expand Down
Loading

0 comments on commit 01d74de

Please sign in to comment.