Skip to content

Commit

Permalink
test including lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 15, 2024
1 parent aa28909 commit 81e93f1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ where
}
let ca = ca.rechunk();
let by = by.rechunk();
polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column");
ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?;
polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive");
let (by, tz) = match by.dtype() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,25 @@ where

let mut out: Vec<_> = vec![None; values.len()];
offsets.enumerate().for_each(|(idx, result)| {
let _ = result.map(|(start, len)| {
if let Ok((start, len)) = result {
let end = start + len;

// On the Python side, if `min_periods` wasn't specified, it is set to
// `1`. In that case, a window is skipped exactly when it is empty,
// i.e. when `start == end`.
if len < (min_periods as IdxSize) {
{}
} else {
if len >= (min_periods as IdxSize) {
// SAFETY:
// we are in bounds
// we are in bound
let res = unsafe { agg_window.update(start as usize, end as usize) };

// SAFETY: `idx` is in bounds because `sorting_indices` was just taken from
// `by`, which has already been checked to be the same length as the values.
unsafe {
let out_idx = sorting_indices.get_unchecked(idx);
*out.get_unchecked_mut(*out_idx as usize) = res;
}
}
});
}
});
let out = PrimitiveArray::<T>::from(out);

Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,11 @@ def test_incorrect_nulls_16246() -> None:
expected = pl.DataFrame({'b': [1,1]})
assert_frame_equal(result, expected)

def test_by_different_length() -> None:
    """A `by` column whose length differs from the values column is rejected."""
    # One value row against two `by` timestamps: deliberately mismatched lengths.
    frame = pl.DataFrame({'b': [1]})
    timestamps = pl.Series([datetime(2020, 1, 1)] * 2)
    with pytest.raises(InvalidOperationError, match='must be the same length'):
        frame.select(pl.col('b').rolling_max_by(timestamps, '1d'))

def test_incorrect_nulls_16246() -> None:
df = pl.concat(
[
Expand Down

0 comments on commit 81e93f1

Please sign in to comment.