Skip to content

Commit

Permalink
feat: check if `by` column is sorted, rather than just checking sorted flag, in `group_by_dynamic`, `upsample`, and `rolling`
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 27, 2024
1 parent d856b49 commit af80d86
Show file tree
Hide file tree
Showing 14 changed files with 89 additions and 106 deletions.
10 changes: 0 additions & 10 deletions crates/polars-core/src/utils/series.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::prelude::*;
use crate::series::unstable::UnstableSeries;
use crate::series::IsSorted;

/// A utility that allocates an [`UnstableSeries`]. The applied function can then use that
/// series container to save heap allocations and swap arrow arrays.
Expand All @@ -14,15 +13,6 @@ where
f(&mut us)
}

pub fn ensure_sorted_arg(s: &Series, operation: &str) -> PolarsResult<()> {
polars_ensure!(!matches!(s.is_sorted_flag(), IsSorted::Not), InvalidOperation: "argument in operation '{}' is not explicitly sorted
- If your data is ALREADY sorted, set the sorted flag with: '.set_sorted()'.
- If your data is NOT sorted, sort the 'expr/series/column' first.
", operation);
Ok(())
}

pub fn handle_casting_failures(input: &Series, output: &Series) -> PolarsResult<()> {
let failure_mask = !input.is_null() & output.is_null();
let failures = input.filter_threaded(&failure_mask, false)?;
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-ops/src/frame/join/asof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::borrow::Cow;
use default::*;
pub use groups::AsofJoinBy;
use polars_core::prelude::*;
use polars_core::utils::ensure_sorted_arg;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use smartstring::alias::String as SmartString;
Expand All @@ -14,6 +13,7 @@ use smartstring::alias::String as SmartString;
use super::_check_categorical_src;
use super::{_finish_join, build_tables, prepare_bytes};
use crate::frame::IntoDf;
use crate::series::SeriesMethods;

trait AsofJoinState<T>: Default {
fn next<F: FnMut(IdxSize) -> Option<T>>(
Expand Down Expand Up @@ -185,8 +185,8 @@ fn check_asof_columns(
a.dtype(), b.dtype()
);
if check_sorted {
ensure_sorted_arg(a, "asof_join")?;
ensure_sorted_arg(b, "asof_join")?;
a.ensure_sorted_arg("asof_join")?;
b.ensure_sorted_arg("asof_join")?;
}
Ok(())
}
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-ops/src/series/ops/various.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ pub trait SeriesMethods: SeriesSealed {
}
}

fn ensure_sorted_arg(&self, operation: &str) -> PolarsResult<()> {
polars_ensure!(self.is_sorted(Default::default())?, InvalidOperation: "argument in operation '{}' is not sorted, please sort the 'expr/series/column' first", operation);
Ok(())
}

/// Checks if a [`Series`] is sorted. Tries to fail fast.
fn is_sorted(&self, options: SortOptions) -> PolarsResult<bool> {
let s = self.as_series();
Expand Down
34 changes: 8 additions & 26 deletions crates/polars-time/src/group_by/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use arrow::legacy::utils::CustomIterTools;
use polars_core::export::rayon::prelude::*;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::ensure_sorted_arg;
use polars_core::utils::flatten::flatten_par;
use polars_core::POOL;
use polars_ops::series::SeriesMethods;
use polars_utils::idx_vec::IdxVec;
use polars_utils::slice::{GetSaferUnchecked, SortedSlice};
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -34,9 +34,6 @@ pub struct DynamicGroupOptions {
pub include_boundaries: bool,
pub closed_window: ClosedWindow,
pub start_by: StartBy,
/// In cases sortedness cannot be checked by the sorted flag,
/// traverse the data to check sortedness.
pub check_sorted: bool,
}

impl Default for DynamicGroupOptions {
Expand All @@ -50,7 +47,6 @@ impl Default for DynamicGroupOptions {
include_boundaries: false,
closed_window: ClosedWindow::Left,
start_by: Default::default(),
check_sorted: true,
}
}
}
Expand All @@ -64,9 +60,6 @@ pub struct RollingGroupOptions {
pub period: Duration,
pub offset: Duration,
pub closed_window: ClosedWindow,
/// In cases sortedness cannot be checked by the sorted flag,
/// traverse the data to check sortedness.
pub check_sorted: bool,
}

impl Default for RollingGroupOptions {
Expand All @@ -76,7 +69,6 @@ impl Default for RollingGroupOptions {
period: Duration::new(1),
offset: Duration::new(1),
closed_window: ClosedWindow::Left,
check_sorted: true,
}
}
}
Expand Down Expand Up @@ -133,10 +125,10 @@ impl Wrap<&DataFrame> {
"rolling window period should be strictly positive",
);
let time = self.0.column(&options.index_column)?.clone();
if group_by.is_empty() && options.check_sorted {
if group_by.is_empty() {
// If by is given, the column must be sorted in the 'by' arg, which we can not check now
// this will be checked when the groups are materialized.
ensure_sorted_arg(&time, "rolling")?;
time.ensure_sorted_arg("rolling")?;
}
let time_type = time.dtype();

Expand Down Expand Up @@ -202,10 +194,10 @@ impl Wrap<&DataFrame> {
options: &DynamicGroupOptions,
) -> PolarsResult<(Series, Vec<Series>, GroupsProxy)> {
let time = self.0.column(&options.index_column)?.rechunk();
if group_by.is_empty() && options.check_sorted {
if group_by.is_empty() {
// If by is given, the column must be sorted in the 'by' arg, which we can not check now
// this will be checked when the groups are materialized.
ensure_sorted_arg(&time, "group_by_dynamic")?;
time.ensure_sorted_arg("group_by_dynamic")?;
}
let time_type = time.dtype();

Expand Down Expand Up @@ -349,9 +341,7 @@ impl Wrap<&DataFrame> {
let dt = unsafe { dt.take_unchecked(base_g.1) };
let vals = dt.downcast_iter().next().unwrap();
let ts = vals.values().as_slice();
if options.check_sorted
&& !matches!(dt.is_sorted_flag(), IsSorted::Ascending)
{
if !matches!(dt.is_sorted_flag(), IsSorted::Ascending) {
check_sortedness_slice(ts)?
}
let (sub_groups, lower, upper) = group_by_windows(
Expand Down Expand Up @@ -428,9 +418,7 @@ impl Wrap<&DataFrame> {
let dt = unsafe { dt.take_unchecked(base_g.1) };
let vals = dt.downcast_iter().next().unwrap();
let ts = vals.values().as_slice();
if options.check_sorted
&& !matches!(dt.is_sorted_flag(), IsSorted::Ascending)
{
if !matches!(dt.is_sorted_flag(), IsSorted::Ascending) {
check_sortedness_slice(ts)?
}
let (sub_groups, _, _) = group_by_windows(
Expand Down Expand Up @@ -573,9 +561,7 @@ impl Wrap<&DataFrame> {
let dt = unsafe { dt_local.take_unchecked(base_g.1) };
let vals = dt.downcast_iter().next().unwrap();
let ts = vals.values().as_slice();
if options.check_sorted
&& !matches!(dt.is_sorted_flag(), IsSorted::Ascending)
{
if !matches!(dt.is_sorted_flag(), IsSorted::Ascending) {
check_sortedness_slice(ts)?
}

Expand Down Expand Up @@ -716,7 +702,6 @@ mod test {
period: Duration::parse("2d"),
offset: Duration::parse("-2d"),
closed_window: ClosedWindow::Right,
..Default::default()
},
)
.unwrap();
Expand Down Expand Up @@ -764,7 +749,6 @@ mod test {
period: Duration::parse("2d"),
offset: Duration::parse("-2d"),
closed_window: ClosedWindow::Right,
..Default::default()
},
)
.unwrap();
Expand Down Expand Up @@ -848,7 +832,6 @@ mod test {
include_boundaries: true,
closed_window: ClosedWindow::Both,
start_by: Default::default(),
..Default::default()
},
)
.unwrap();
Expand Down Expand Up @@ -969,7 +952,6 @@ mod test {
include_boundaries: true,
closed_window: ClosedWindow::Both,
start_by: Default::default(),
..Default::default()
},
)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-time/src/upsample.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#[cfg(feature = "timezones")]
use polars_core::chunked_array::temporal::parse_time_zone;
use polars_core::prelude::*;
use polars_core::utils::ensure_sorted_arg;
use polars_ops::prelude::*;
use polars_ops::series::SeriesMethods;

use crate::prelude::*;

Expand Down Expand Up @@ -128,7 +128,7 @@ fn upsample_impl(
stable: bool,
) -> PolarsResult<DataFrame> {
let s = source.column(index_column)?;
ensure_sorted_arg(s, "upsample")?;
s.ensure_sorted_arg("upsample")?;
let time_type = s.dtype();
if matches!(time_type, DataType::Date) {
let mut df = source.clone();
Expand Down
18 changes: 13 additions & 5 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5472,7 +5472,7 @@ def rolling(
offset: str | timedelta | None = None,
closed: ClosedInterval = "right",
group_by: IntoExpr | Iterable[IntoExpr] | None = None,
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> RollingGroupBy:
"""
Create rolling groups based on a temporal or integer column.
Expand Down Expand Up @@ -5547,6 +5547,10 @@ def rolling(
data within the groups is sorted, you can set this to `False`.
Doing so incorrectly will lead to incorrect output
.. deprecated:: 0.20.31
Sortedness is now verified in a quick manner, you can safely remove
this argument.
Returns
-------
RollingGroupBy
Expand Down Expand Up @@ -5622,7 +5626,7 @@ def group_by_dynamic(
label: Label = "left",
group_by: IntoExpr | Iterable[IntoExpr] | None = None,
start_by: StartBy = "window",
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> DynamicGroupBy:
"""
Group based on a time value (or index value of type Int32, Int64).
Expand Down Expand Up @@ -5707,6 +5711,10 @@ def group_by_dynamic(
data within the groups is sorted, you can set this to `False`.
Doing so incorrectly will lead to incorrect output
.. deprecated:: 0.20.31
Sortedness is now verified in a quick manner, you can safely remove
this argument.
Returns
-------
DynamicGroupBy
Expand Down Expand Up @@ -10733,7 +10741,7 @@ def groupby_rolling(
offset: str | timedelta | None = None,
closed: ClosedInterval = "right",
by: IntoExpr | Iterable[IntoExpr] | None = None,
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> RollingGroupBy:
"""
Create rolling groups based on a time, Int32, or Int64 column.
Expand Down Expand Up @@ -10787,7 +10795,7 @@ def group_by_rolling(
offset: str | timedelta | None = None,
closed: ClosedInterval = "right",
by: IntoExpr | Iterable[IntoExpr] | None = None,
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> RollingGroupBy:
"""
Create rolling groups based on a time, Int32, or Int64 column.
Expand Down Expand Up @@ -10845,7 +10853,7 @@ def groupby_dynamic(
closed: ClosedInterval = "left",
by: IntoExpr | Iterable[IntoExpr] | None = None,
start_by: StartBy = "window",
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> DynamicGroupBy:
"""
Group based on a time value (or index value of type Int32, Int64).
Expand Down
22 changes: 12 additions & 10 deletions py-polars/polars/dataframe/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,13 @@ def __init__(
offset: str | timedelta | None,
closed: ClosedInterval,
group_by: IntoExpr | Iterable[IntoExpr] | None,
check_sorted: bool,
check_sorted: bool | None = None,
):
if check_sorted is not None:
issue_deprecation_warning(
"`check_sorted` is now deprecated in `rolling`, you can safely remove this argument.",
version="0.20.31",
)
period = parse_as_duration_string(period)
offset = parse_as_duration_string(offset)

Expand All @@ -817,7 +822,6 @@ def __init__(
self.offset = offset
self.closed = closed
self.group_by = group_by
self.check_sorted = check_sorted

def __iter__(self) -> Self:
temp_col = "__POLARS_GB_GROUP_INDICES"
Expand All @@ -829,7 +833,6 @@ def __iter__(self) -> Self:
offset=self.offset,
closed=self.closed,
group_by=self.group_by,
check_sorted=self.check_sorted,
)
.agg(F.first().agg_groups().alias(temp_col))
.collect(no_optimization=True)
Expand Down Expand Up @@ -888,7 +891,6 @@ def agg(
offset=self.offset,
closed=self.closed,
group_by=self.group_by,
check_sorted=self.check_sorted,
)
.agg(*aggs, **named_aggs)
.collect(no_optimization=True)
Expand Down Expand Up @@ -931,7 +933,6 @@ def map_groups(
offset=self.offset,
closed=self.closed,
group_by=self.group_by,
check_sorted=self.check_sorted,
)
.map_groups(function, schema)
.collect(no_optimization=True)
Expand Down Expand Up @@ -983,8 +984,13 @@ def __init__(
label: Label,
group_by: IntoExpr | Iterable[IntoExpr] | None,
start_by: StartBy,
check_sorted: bool,
check_sorted: bool | None = None,
):
if check_sorted is not None:
issue_deprecation_warning(
"`check_sorted` is now deprecated in `group_by_dynamic`, you can safely remove this argument.",
version="0.20.31",
)
every = parse_as_duration_string(every)
period = parse_as_duration_string(period)
offset = parse_as_duration_string(offset)
Expand All @@ -1000,7 +1006,6 @@ def __init__(
self.closed = closed
self.group_by = group_by
self.start_by = start_by
self.check_sorted = check_sorted

def __iter__(self) -> Self:
temp_col = "__POLARS_GB_GROUP_INDICES"
Expand All @@ -1017,7 +1022,6 @@ def __iter__(self) -> Self:
closed=self.closed,
group_by=self.group_by,
start_by=self.start_by,
check_sorted=self.check_sorted,
)
.agg(F.first().agg_groups().alias(temp_col))
.collect(no_optimization=True)
Expand Down Expand Up @@ -1081,7 +1085,6 @@ def agg(
closed=self.closed,
group_by=self.group_by,
start_by=self.start_by,
check_sorted=self.check_sorted,
)
.agg(*aggs, **named_aggs)
.collect(no_optimization=True)
Expand Down Expand Up @@ -1128,7 +1131,6 @@ def map_groups(
closed=self.closed,
group_by=self.group_by,
start_by=self.start_by,
check_sorted=self.check_sorted,
)
.map_groups(function, schema)
.collect(no_optimization=True)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3771,7 +3771,7 @@ def rolling(
period: str | timedelta,
offset: str | timedelta | None = None,
closed: ClosedInterval = "right",
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> Self:
"""
Create rolling groups based on a temporal or integer column.
Expand Down Expand Up @@ -3875,7 +3875,7 @@ def rolling(
offset = parse_as_duration_string(offset)

return self._from_pyexpr(
self._pyexpr.rolling(index_column, period, offset, closed, check_sorted)
self._pyexpr.rolling(index_column, period, offset, closed)
)

def is_unique(self) -> Self:
Expand Down
Loading

0 comments on commit af80d86

Please sign in to comment.