Skip to content

Commit

Permalink
feat: sort on behalf of user
Browse files (browse the repository at this point in the history)
  • Loading branch information
MarcoGorelli committed May 16, 2024
1 parent 98a2d9b commit 665da25
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 173 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/polars-time/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ once_cell = { workspace = true }
regex = { workspace = true }
serde = { workspace = true, optional = true }
smartstring = { workspace = true }
bytemuck = {workspace = true}

[dev-dependencies]
polars-ops = { workspace = true, features = ["abs"] }
Expand Down
88 changes: 53 additions & 35 deletions crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ fn rolling_agg_by<T>(
TimeUnit,
Option<&TimeZone>,
DynArgs,
Option<&[IdxSize]>,
) -> PolarsResult<ArrayRef>,
) -> PolarsResult<Series>
where
Expand All @@ -83,20 +84,9 @@ where
}
let ca = ca.rechunk();
let by = by.rechunk();
polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column");
ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?;
polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive");
if by.is_sorted_flag() != IsSorted::Ascending && options.warn_if_unsorted {
polars_warn!(format!(
"Series is not known to be sorted by `by` column in `rolling_*_by` operation.\n\
\n\
To silence this warning, you may want to try:\n\
- sorting your data by your `by` column beforehand;\n\
- setting `.set_sorted()` if you already know your data is sorted;\n\
- passing `warn_if_unsorted=False` if this warning is a false-positive\n \
(this is known to happen when combining rolling aggregations with `over`);\n\n\
before passing calling the rolling aggregation function.\n",
));
}
let (by, tz) = match by.dtype() {
DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz),
DataType::Date => (
Expand All @@ -109,32 +99,60 @@ where
"date/datetime"),
};
let by = by.datetime().unwrap();
let by_values = by.cont_slice().map_err(|_| {
polars_err!(
ComputeError:
"`by` column should not have null values in 'rolling by' expression"
)
})?;
let tu = by.time_unit();

let arr = ca.downcast_iter().next().unwrap();
if arr.null_count() > 0 {
polars_bail!(InvalidOperation: "'Expr.rolling_*(..., by=...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'")
}
let values = arr.values().as_slice();
let func = rolling_agg_fn_dynamic;

let arr = func(
values,
options.window_size,
by_values,
options.closed_window,
options.min_periods,
tu,
tz.as_ref(),
options.fn_params,
)?;
Series::try_from((ca.name(), arr))
let out: ArrayRef = if matches!(by.is_sorted_flag(), IsSorted::Ascending) {
let arr = ca.downcast_iter().next().unwrap();
if arr.null_count() > 0 {
polars_bail!(InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'")
}
let by_values = by.cont_slice().map_err(|_| {
polars_err!(
ComputeError:
"`by` column should not have null values in 'rolling by' expression"
)
})?;
let values = arr.values().as_slice();
func(
values,
options.window_size,
by_values,
options.closed_window,
options.min_periods,
tu,
tz.as_ref(),
options.fn_params,
None,
)?
} else {
let sorting_indices = by.arg_sort(Default::default());
let ca = unsafe { ca.take_unchecked(&sorting_indices) };
let by = unsafe { by.take_unchecked(&sorting_indices) };
let arr = ca.downcast_iter().next().unwrap();
if arr.null_count() > 0 {
polars_bail!(InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'")
}
let by_values = by.cont_slice().map_err(|_| {
polars_err!(
ComputeError:
"`by` column should not have null values in 'rolling by' expression"
)
})?;
let values = arr.values().as_slice();
func(
values,
options.window_size,
by_values,
options.closed_window,
options.min_periods,
tu,
tz.as_ref(),
options.fn_params,
Some(sorting_indices.cont_slice().unwrap()),
)?
};
Series::try_from((ca.name(), out))
}

pub trait SeriesOpsTime: AsSeries {
Expand Down
3 changes: 0 additions & 3 deletions crates/polars-time/src/chunkedarray/rolling_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ pub struct RollingOptionsDynamicWindow {
/// Optional parameters for the rolling function
#[cfg_attr(feature = "serde", serde(skip))]
pub fn_params: DynArgs,
/// Warn if data is not known to be sorted by `by` column (if passed)
pub warn_if_unsorted: bool,
}

#[cfg(feature = "rolling_window_by")]
Expand All @@ -33,7 +31,6 @@ impl PartialEq for RollingOptionsDynamicWindow {
self.window_size == other.window_size
&& self.min_periods == other.min_periods
&& self.closed_window == other.closed_window
&& self.warn_if_unsorted == other.warn_if_unsorted
&& self.fn_params.is_none()
&& other.fn_params.is_none()
}
Expand Down
Loading

0 comments on commit 665da25

Please sign in to comment.