diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 6e5530249a7fa..3ee24266444e6 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -84,6 +84,7 @@ where } let ca = ca.rechunk(); let by = by.rechunk(); + polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column"); ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); let (by, tz) = match by.dtype() { diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index d1db4448aa2de..59a798ced3bd3 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -79,24 +79,25 @@ where let mut out: Vec<_> = vec![None; values.len()]; offsets.enumerate().for_each(|(idx, result)| { - let _ = result.map(|(start, len)| { + if let Ok((start, len)) = result { let end = start + len; // On the Python side, if `min_periods` wasn't specified, it is set to // `1`. In that case, this condition is the same as checking // `if start == end`. - if len < (min_periods as IdxSize) { - {} - } else { + if len >= (min_periods as IdxSize) { // SAFETY: - // we are in bounds + // we are in bound let res = unsafe { agg_window.update(start as usize, end as usize) }; + + // SAFETY: `idx` is in bounds because `sorting_indices` was just taken from + // `by`, which has already been checked to be the same length as the values. unsafe { let out_idx = sorting_indices.get_unchecked(idx); *out.get_unchecked_mut(*out_idx as usize) = res; } } - }); + } }); let out = PrimitiveArray::::from(out); diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8a63fa6ddd4ec..e23cce0412c10 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -1041,6 +1041,11 @@ def test_incorrect_nulls_16246() -> None: expected = pl.DataFrame({'b': [1,1]}) assert_frame_equal(result, expected) +def test_by_different_length() -> None: + df = pl.DataFrame({'b': [1]}) + with pytest.raises(InvalidOperationError, match='must be the same length'): + df.select(pl.col('b').rolling_max_by(pl.Series([datetime(2020, 1, 1)]*2), '1d')) + def test_incorrect_nulls_16246() -> None: df = pl.concat( [