Skip to content

Commit

Permalink
fix(rust, python): std when ddof>=n_values returns None even in rolli…
Browse files Browse the repository at this point in the history
…ng context (pola-rs#11750)
  • Loading branch information
MarcoGorelli authored Mar 8, 2024
1 parent de578b5 commit 6b23f79
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
let sum = self.sum.update(start, end);
sum / NumCast::from(end - start).unwrap()
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let sum = self.sum.update(start, end).unwrap_unchecked();
Some(sum / NumCast::from(end - start).unwrap())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ macro_rules! minmax_window {
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
//For details see: https://github.com/pola-rs/polars/pull/9277#issuecomment-1581401692
self.last_start = start; // Don't care where the last one started
let old_last_end = self.last_end; // But we need this
Expand All @@ -168,10 +168,10 @@ macro_rules! minmax_window {
if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) {
// The entering extremum "beats" the previous extremum so we can ignore the overlap
self.update_m_and_m_idx(entering.unwrap());
return self.m;
return Some(self.m);
} else if self.m_idx >= start || empty_overlap {
// The previous extremum didn't drop off. Keep it
return self.m;
return Some(self.m);
}
// Otherwise get the min of the overlapping window and the entering min
match (
Expand All @@ -191,7 +191,7 @@ macro_rules! minmax_window {
(None, None) => unreachable!(),
}

self.m
Some(self.m)
}
}
};
Expand Down Expand Up @@ -241,7 +241,7 @@ macro_rules! rolling_minmax_func {
_params: DynArgs,
) -> PolarsResult<ArrayRef>
where
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T>,
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
{
let offset_fn = match center {
true => det_offsets_center,
Expand Down
34 changes: 20 additions & 14 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ mod min_max;
mod quantile;
mod sum;
mod variance;

use std::fmt::Debug;

pub use mean::*;
pub use min_max::*;
use num_traits::{Float, NumCast};
use num_traits::{Float, Num, NumCast};
pub use quantile::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand All @@ -28,7 +27,7 @@ pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
///
/// # Safety
/// `start` and `end` must be within the windows bounds
unsafe fn update(&mut self, start: usize, end: usize) -> T;
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
}

// Use an aggregation window that maintains the state
Expand All @@ -42,27 +41,34 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Agg: RollingAggWindowNoNulls<'a, T>,
T: Debug + NativeType,
T: Debug + NativeType + Num,
{
let len = values.len();
let (start, end) = det_offsets_fn(0, window_size, len);
let mut agg_window = Agg::new(values, start, end, params);
if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
if validity.iter().all(|x| !x) {
return Ok(Box::new(PrimitiveArray::<T>::new_null(
T::PRIMITIVE.into(),
len,
)));
}
}

let out = (0..len)
.map(|idx| {
let (start, end) = det_offsets_fn(idx, window_size, len);
// SAFETY:
// we are in bounds
unsafe { agg_window.update(start, end) }
if end - start < min_periods {
None
} else {
// SAFETY:
// we are in bounds
unsafe { agg_window.update(start, end) }
}
})
.collect_trusted::<Vec<_>>();

let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
Ok(Box::new(PrimitiveArray::new(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
)))
let arr = PrimitiveArray::from(out);
Ok(Box::new(arr))
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let vals = self.sorted.update(start, end);
let length = vals.len();

Expand All @@ -48,13 +48,13 @@ impl<
let float_idx_top = (length_f - 1.0) * self.prob;
let top_idx = float_idx_top.ceil() as usize;
return if idx == top_idx {
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
} else {
let proportion = T::from(float_idx_top - idx as f64).unwrap();
let vi = unsafe { *vals.get_unchecked_release(idx) };
let vj = unsafe { *vals.get_unchecked_release(top_idx) };

proportion * (vj - vi) + vi
Some(proportion * (vj - vi) + vi)
};
},
Midpoint => {
Expand All @@ -66,7 +66,7 @@ impl<
return if top_idx == idx {
// SAFETY:
// we are in bounds
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
} else {
// SAFETY:
// we are in bounds
Expand All @@ -77,7 +77,7 @@ impl<
)
};

(mid + mid_plus_1) / (T::one() + T::one())
Some((mid + mid_plus_1) / (T::one() + T::one()))
};
},
Nearest => {
Expand All @@ -93,7 +93,7 @@ impl<

// SAFETY:
// we are in bounds
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
}
}

Expand Down
13 changes: 10 additions & 3 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end {
Expand Down Expand Up @@ -60,7 +60,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
}
}
self.last_end = end;
self.sum
Some(self.sum)
}
}

Expand All @@ -73,7 +73,14 @@ pub fn rolling_sum<T>(
_params: DynArgs,
) -> PolarsResult<ArrayRef>
where
T: NativeType + std::iter::Sum + NumCast + Mul<Output = T> + AddAssign + SubAssign + IsFloat,
T: NativeType
+ std::iter::Sum
+ NumCast
+ Mul<Output = T>
+ AddAssign
+ SubAssign
+ IsFloat
+ Num,
{
match (center, weights) {
(true, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
Expand Down
30 changes: 13 additions & 17 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end || self.last_recompute > 128 {
Expand Down Expand Up @@ -68,7 +68,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<
}
}
self.last_end = end;
self.sum_of_squares
Some(self.sum_of_squares)
}
}

Expand Down Expand Up @@ -108,25 +108,24 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let count: T = NumCast::from(end - start).unwrap();
let sum_of_squares = self.sum_of_squares.update(start, end);
let mean = self.mean.update(start, end);
let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked();
let mean = self.mean.update(start, end).unwrap_unchecked();

let denom = count - NumCast::from(self.ddof).unwrap();
if end - start == 1 {
T::zero()
} else if denom <= T::zero() {
//ddof would be greater than # of observations
T::infinity()
if denom <= T::zero() {
None
} else if end - start == 1 {
Some(T::zero())
} else {
let out = (sum_of_squares - count * mean * mean) / denom;
// variance cannot be negative.
// if it is negative it is due to numeric instability
if out < T::zero() {
T::zero()
Some(T::zero())
} else {
out
Some(out)
}
}
}
Expand Down Expand Up @@ -208,14 +207,11 @@ mod test {

let out = rolling_var(values, 2, 1, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out
.into_iter()
.map(|v| v.copied().unwrap())
.collect::<Vec<_>>();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
// we cannot compare nans, so we compare the string values
assert_eq!(
format!("{:?}", out.as_slice()),
format!("{:?}", &[0.0, 8.0, 2.0, 0.5])
format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
);
// test nan handling.
let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
Expand Down
34 changes: 29 additions & 5 deletions crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ where
None
} else {
// SAFETY: we are in bounds.
Some(unsafe { agg_window.update(start as usize, end as usize) })
unsafe { agg_window.update(start as usize, end as usize) }
}
})
.collect::<PrimitiveArray<T>>()
Expand Down Expand Up @@ -799,7 +799,13 @@ where
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
Expand Down Expand Up @@ -861,7 +867,13 @@ where
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
Expand Down Expand Up @@ -1012,7 +1024,13 @@ where
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.var(ddof)
Expand Down Expand Up @@ -1054,7 +1072,13 @@ where
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.std(ddof)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ where
} else {
// SAFETY:
// we are in bounds
Some(unsafe { agg_window.update(start as usize, end as usize) })
unsafe { agg_window.update(start as usize, end as usize) }
}
})
})
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6834,7 +6834,7 @@ def rolling_std(
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ null │
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0
│ 1 ┆ 2001-01-01 01:00:00 ┆ null
│ 2 ┆ 2001-01-01 02:00:00 ┆ 0.707107 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 0.707107 │
│ 4 ┆ 2001-01-01 04:00:00 ┆ 0.707107 │
Expand All @@ -6859,7 +6859,7 @@ def rolling_std(
│ --- ┆ --- ┆ --- │
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0
│ 0 ┆ 2001-01-01 00:00:00 ┆ null
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.707107 │
│ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │
Expand Down Expand Up @@ -7081,7 +7081,7 @@ def rolling_var(
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ null │
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0
│ 1 ┆ 2001-01-01 01:00:00 ┆ null
│ 2 ┆ 2001-01-01 02:00:00 ┆ 0.5 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 0.5 │
│ 4 ┆ 2001-01-01 04:00:00 ┆ 0.5 │
Expand All @@ -7106,7 +7106,7 @@ def rolling_var(
│ --- ┆ --- ┆ --- │
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0
│ 0 ┆ 2001-01-01 00:00:00 ┆ null
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.5 │
│ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │
Expand Down
Loading

0 comments on commit 6b23f79

Please sign in to comment.