diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index b00d3f5c34cb..054089ff404f 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -125,6 +125,7 @@ range = ["polars-plan/range"] mode = ["polars-plan/mode"] cum_agg = ["polars-plan/cum_agg"] interpolate = ["polars-plan/interpolate"] +interpolate_by = ["polars-plan/interpolate_by"] rolling_window = [ "polars-plan/rolling_window", ] @@ -270,6 +271,7 @@ features = [ "futures", "hist", "interpolate", + "interpolate_by", "ipc", "is_first_distinct", "is_in", diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 75498da2e313..3bbdb10fcaf0 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -91,6 +91,7 @@ string_encoding = ["base64", "hex"] # ops to_dummies = [] interpolate = [] +interpolate_by = [] list_to_struct = ["polars-core/dtype-struct"] array_to_struct = ["polars-core/dtype-array", "polars-core/dtype-struct"] list_count = [] diff --git a/crates/polars-ops/src/chunked_array/mod.rs b/crates/polars-ops/src/chunked_array/mod.rs index 31729d7c7c67..c0f1941a90dc 100644 --- a/crates/polars-ops/src/chunked_array/mod.rs +++ b/crates/polars-ops/src/chunked_array/mod.rs @@ -3,8 +3,6 @@ pub mod array; mod binary; #[cfg(feature = "timezones")] pub mod datetime; -#[cfg(feature = "interpolate")] -mod interpolate; pub mod list; #[cfg(feature = "propagate_nans")] pub mod nan_propagating_aggregate; @@ -36,8 +34,6 @@ pub use datetime::*; pub use gather::*; #[cfg(feature = "hist")] pub use hist::*; -#[cfg(feature = "interpolate")] -pub use interpolate::*; pub use list::*; #[allow(unused_imports)] use polars_core::prelude::*; diff --git a/crates/polars-ops/src/chunked_array/interpolate.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs similarity index 94% rename from crates/polars-ops/src/chunked_array/interpolate.rs rename to crates/polars-ops/src/series/ops/interpolation/interpolate.rs index 192bcee68ac4..0263b506920d 100644 --- a/crates/polars-ops/src/chunked_array/interpolate.rs +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs @@ -8,26 +8,7 @@ use polars_core::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -fn linear_itp(low: T, step: T, slope: T) -> T -where - T: Sub + Mul + Add + Div, -{ - low + step * slope -} - -fn nearest_itp(low: T, step: T, diff: T, steps_n: T) -> T -where - T: Sub + Mul + Add + Div + PartialOrd + Copy, -{ - // 5 - 1 = 5 -> low - // 5 - 2 = 3 -> low - // 5 - 3 = 2 -> high - if (steps_n - step) > step { - low - } else { - low + diff - } -} +use super::{linear_itp, nearest_itp}; fn near_interp(low: T, high: T, steps: IdxSize, steps_n: T, out: &mut Vec) where @@ -75,11 +56,11 @@ where let first = chunked_arr.first_non_null().unwrap(); let last = chunked_arr.last_non_null().unwrap() + 1; - // Fill out with first. + // Fill out with `first` nulls. let mut out = Vec::with_capacity(chunked_arr.len()); let mut iter = chunked_arr.iter().skip(first); for _ in 0..first { - out.push(Zero::zero()) + out.push(Zero::zero()); } // The next element of `iter` is definitely `Some(Some(v))`, because we skipped the first @@ -109,11 +90,11 @@ where validity.extend_constant(chunked_arr.len(), true); for i in 0..first { - validity.set(i, false); + unsafe { validity.set_unchecked(i, false) }; } for i in last..chunked_arr.len() { - validity.set(i, false); + unsafe { validity.set_unchecked(i, false) }; out.push(Zero::zero()) } diff --git a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs new file mode 100644 index 000000000000..f425ffaac7e7 --- /dev/null +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs @@ -0,0 +1,332 @@ +use std::ops::{Add, Div, Mul, Sub}; + +use arrow::array::PrimitiveArray; +use arrow::bitmap::MutableBitmap; +use bytemuck::allocation::zeroed_vec; +use polars_core::export::num::{NumCast, Zero}; +use polars_core::prelude::*; +use polars_utils::slice::SliceAble; + +use super::linear_itp; + +/// # Safety +/// - `x` must be non-empty. +#[inline] +unsafe fn signed_interp_by_sorted(y_start: T, y_end: T, x: &[F], out: &mut Vec) +where + T: Sub + + Mul + + Add + + Div + + NumCast + + Copy + + Zero, + F: Sub + NumCast + Copy, +{ + let range_y = y_end - y_start; + let x_start; + let range_x; + let iter; + unsafe { + x_start = x.get_unchecked(0); + range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap(); + iter = x.slice_unchecked(1..x.len() - 1).iter(); + } + let slope = range_y / range_x; + for x_i in iter { + let x_delta = NumCast::from(*x_i - *x_start).unwrap(); + let v = linear_itp(y_start, x_delta, slope); + out.push(v) + } +} + +/// # Safety +/// - `x` must be non-empty. +/// - `sorting_indices` must be the same size as `x` +#[inline] +unsafe fn signed_interp_by( + y_start: T, + y_end: T, + x: &[F], + out: &mut [T], + sorting_indices: &[IdxSize], +) where + T: Sub + + Mul + + Add + + Div + + NumCast + + Copy + + Zero, + F: Sub + NumCast + Copy, +{ + let range_y = y_end - y_start; + let x_start; + let range_x; + let iter; + unsafe { + x_start = x.get_unchecked(0); + range_x = NumCast::from(*x.get_unchecked(x.len() - 1) - *x_start).unwrap(); + iter = x.slice_unchecked(1..x.len() - 1).iter(); + } + let slope = range_y / range_x; + for (idx, x_i) in iter.enumerate() { + let x_delta = NumCast::from(*x_i - *x_start).unwrap(); + let v = linear_itp(y_start, x_delta, slope); + unsafe { + let out_idx = sorting_indices.get_unchecked(idx + 1); + *out.get_unchecked_mut(*out_idx as usize) = v; + } + } +} + +fn interpolate_impl_by_sorted( + chunked_arr: &ChunkedArray, + by: &ChunkedArray, + interpolation_branch: I, +) -> PolarsResult> +where + T: PolarsNumericType, + F: PolarsIntegerType, + I: Fn(T::Native, T::Native, &[F::Native], &mut Vec), +{ + // This implementation differs from pandas as that boundary None's are not removed. + // This prevents a lot of errors due to expressions leading to different lengths. + if !chunked_arr.has_validity() || chunked_arr.null_count() == chunked_arr.len() { + return Ok(chunked_arr.clone()); + } + + polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression"); + let by = by.rechunk(); + let by_values = by.cont_slice().unwrap(); + + // We first find the first and last so that we can set the null buffer. + let first = chunked_arr.first_non_null().unwrap(); + let last = chunked_arr.last_non_null().unwrap() + 1; + + // Fill out with `first` nulls. + let mut out = Vec::with_capacity(chunked_arr.len()); + let mut iter = chunked_arr.iter().enumerate().skip(first); + for _ in 0..first { + out.push(Zero::zero()); + } + + // The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first + // `first` elements and if all values were missing we'd have done an early return. + let (mut low_idx, opt_low) = iter.next().unwrap(); + let mut low = opt_low.unwrap(); + out.push(low); + while let Some((idx, next)) = iter.next() { + if let Some(v) = next { + out.push(v); + low = v; + low_idx = idx; + } else { + for (high_idx, next) in iter.by_ref() { + if let Some(high) = next { + // SAFETY: we are in bounds, and `x` is non-empty. + unsafe { + let x = &by_values.slice_unchecked(low_idx..high_idx + 1); + interpolation_branch(low, high, x, &mut out); + } + out.push(high); + low = high; + low_idx = high_idx; + break; + } + } + } + } + if first != 0 || last != chunked_arr.len() { + let mut validity = MutableBitmap::with_capacity(chunked_arr.len()); + validity.extend_constant(chunked_arr.len(), true); + + for i in 0..first { + unsafe { validity.set_unchecked(i, false) }; + } + + for i in last..chunked_arr.len() { + unsafe { validity.set_unchecked(i, false) } + out.push(Zero::zero()); + } + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(true), + out.into(), + Some(validity.into()), + ); + Ok(ChunkedArray::with_chunk(chunked_arr.name(), array)) + } else { + Ok(ChunkedArray::from_vec(chunked_arr.name(), out)) + } +} + +// Sort on behalf of user +fn interpolate_impl_by( + ca: &ChunkedArray, + by: &ChunkedArray, + interpolation_branch: I, +) -> PolarsResult> +where + T: PolarsNumericType, + F: PolarsIntegerType, + I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]), +{ + // This implementation differs from pandas as that boundary None's are not removed. + // This prevents a lot of errors due to expressions leading to different lengths. + if !ca.has_validity() || ca.null_count() == ca.len() { + return Ok(ca.clone()); + } + + polars_ensure!(by.null_count() == 0, InvalidOperation: "null values in `by` column are not yet supported in 'interpolate_by' expression"); + let sorting_indices = by.arg_sort(Default::default()); + let sorting_indices = sorting_indices + .cont_slice() + .expect("arg sort produces single chunk"); + let by_sorted = unsafe { by.take_unchecked(sorting_indices) }; + let ca_sorted = unsafe { ca.take_unchecked(sorting_indices) }; + let by_sorted_values = by_sorted + .cont_slice() + .expect("We already checked for nulls, and `take_unchecked` produces single chunk"); + + // We first find the first and last so that we can set the null buffer. + let first = ca_sorted.first_non_null().unwrap(); + let last = ca_sorted.last_non_null().unwrap() + 1; + + let mut out = zeroed_vec(ca_sorted.len()); + let mut iter = ca_sorted.iter().enumerate().skip(first); + + // The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first + // `first` elements and if all values were missing we'd have done an early return. + let (mut low_idx, opt_low) = iter.next().unwrap(); + let mut low = opt_low.unwrap(); + unsafe { + let out_idx = sorting_indices.get_unchecked(low_idx); + *out.get_unchecked_mut(*out_idx as usize) = low; + } + while let Some((idx, next)) = iter.next() { + if let Some(v) = next { + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = v; + } + low = v; + low_idx = idx; + } else { + for (high_idx, next) in iter.by_ref() { + if let Some(high) = next { + // SAFETY: we are in bounds, and the slices are the same length (and non-empty). + unsafe { + interpolation_branch( + low, + high, + by_sorted_values.slice_unchecked(low_idx..high_idx + 1), + &mut out, + sorting_indices.slice_unchecked(low_idx..high_idx + 1), + ); + let out_idx = sorting_indices.get_unchecked(high_idx); + *out.get_unchecked_mut(*out_idx as usize) = high; + } + low = high; + low_idx = high_idx; + break; + } + } + } + } + if first != 0 || last != ca_sorted.len() { + let mut validity = MutableBitmap::with_capacity(ca_sorted.len()); + validity.extend_constant(ca_sorted.len(), true); + + for i in 0..first { + unsafe { + let out_idx = sorting_indices.get_unchecked(i); + validity.set_unchecked(*out_idx as usize, false); + } + } + + for i in last..ca_sorted.len() { + unsafe { + let out_idx = sorting_indices.get_unchecked(i); + validity.set_unchecked(*out_idx as usize, false); + } + } + + let array = PrimitiveArray::new( + T::get_dtype().to_arrow(true), + out.into(), + Some(validity.into()), + ); + Ok(ChunkedArray::with_chunk(ca_sorted.name(), array)) + } else { + Ok(ChunkedArray::from_vec(ca_sorted.name(), out)) + } +} + +pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResult { + polars_ensure!(s.len() == by.len(), InvalidOperation: "`by` column must be the same length as Series ({}), got {}", s.len(), by.len()); + + fn func( + ca: &ChunkedArray, + by: &ChunkedArray, + is_sorted: bool, + ) -> PolarsResult + where + T: PolarsNumericType, + F: PolarsIntegerType, + ChunkedArray: IntoSeries, + { + if is_sorted { + interpolate_impl_by_sorted(ca, by, |y_start, y_end, x, out| unsafe { + signed_interp_by_sorted(y_start, y_end, x, out) + }) + .map(|x| x.into_series()) + } else { + interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe { + signed_interp_by(y_start, y_end, x, out, sorting_indices) + }) + .map(|x| x.into_series()) + } + } + + match (s.dtype(), by.dtype()) { + (DataType::Float64, DataType::Int64) => { + func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Int32) => { + func(s.f64().unwrap(), by.i32().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::UInt64) => { + func(s.f64().unwrap(), by.u64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::UInt32) => { + func(s.f64().unwrap(), by.u32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Int64) => { + func(s.f32().unwrap(), by.i64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Int32) => { + func(s.f32().unwrap(), by.i32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::UInt64) => { + func(s.f32().unwrap(), by.u64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::UInt32) => { + func(s.f32().unwrap(), by.u32().unwrap(), by_is_sorted) + }, + #[cfg(feature = "dtype-date")] + (_, DataType::Date) => interpolate_by(s, &by.cast(&DataType::Int32).unwrap(), by_is_sorted), + #[cfg(feature = "dtype-datetime")] + (_, DataType::Datetime(_, _)) => { + interpolate_by(s, &by.cast(&DataType::Int64).unwrap(), by_is_sorted) + }, + (DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => { + interpolate_by(&s.cast(&DataType::Float64).unwrap(), by, by_is_sorted) + }, + _ => { + polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ + Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ + UInt64, or UInt32") + }, + } +} diff --git a/crates/polars-ops/src/series/ops/interpolation/mod.rs b/crates/polars-ops/src/series/ops/interpolation/mod.rs new file mode 100644 index 000000000000..44511ff35b4b --- /dev/null +++ b/crates/polars-ops/src/series/ops/interpolation/mod.rs @@ -0,0 +1,26 @@ +use std::ops::{Add, Div, Mul, Sub}; +#[cfg(feature = "interpolate")] +pub mod interpolate; +#[cfg(feature = "interpolate_by")] +pub mod interpolate_by; + +fn linear_itp(low: T, step: T, slope: T) -> T +where + T: Sub + Mul + Add + Div, +{ + low + step * slope +} + +fn nearest_itp(low: T, step: T, diff: T, steps_n: T) -> T +where + T: Sub + Mul + Add + Div + PartialOrd + Copy, +{ + // 5 - 1 = 5 -> low + // 5 - 2 = 3 -> low + // 5 - 3 = 2 -> high + if (steps_n - step) > step { + low + } else { + low + diff + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index a87a9ef9a29d..75c40c6d500d 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -25,6 +25,8 @@ mod fused; mod horizontal; mod index; mod int_range; +#[cfg(any(feature = "interpolate_by", feature = "interpolate"))] +mod interpolation; #[cfg(feature = "is_between")] mod is_between; #[cfg(feature = "is_first_distinct")] @@ -89,6 +91,12 @@ pub use fused::*; pub use horizontal::*; pub use index::*; pub use int_range::*; +#[cfg(feature = "interpolate")] +pub use interpolation::interpolate::*; +#[cfg(feature = "interpolate_by")] +pub use interpolation::interpolate_by::*; +#[cfg(any(feature = "interpolate", feature = "interpolate_by"))] +pub use interpolation::*; #[cfg(feature = "is_between")] pub use is_between::*; #[cfg(feature = "is_first_distinct")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index cfae4ffa0477..92113dc29b04 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -118,6 +118,7 @@ range = [] mode = ["polars-ops/mode"] cum_agg = ["polars-ops/cum_agg"] interpolate = ["polars-ops/interpolate"] +interpolate_by = ["polars-ops/interpolate_by"] rolling_window = [ "polars-core/rolling_window", "polars-time/rolling_window", @@ -253,6 +254,7 @@ features = [ "peaks", "abs", "interpolate", + "interpolate_by", "list_count", "cum_agg", "top_k", diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index ac9bc04731b9..dd7c6eba0f86 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -24,6 +24,13 @@ pub(super) fn interpolate(s: &Series, method: InterpolationMethod) -> PolarsResu Ok(polars_ops::prelude::interpolate(s, method)) } +#[cfg(feature = "interpolate_by")] +pub(super) fn interpolate_by(s: &[Series]) -> PolarsResult { + let by = &s[1]; + let by_is_sorted = by.is_sorted(Default::default())?; + polars_ops::prelude::interpolate_by(&s[0], by, by_is_sorted) +} + pub(super) fn to_physical(s: &Series) -> PolarsResult { Ok(s.to_physical_repr().into_owned()) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index e8746609c66a..5909779440a9 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -234,6 +234,8 @@ pub enum FunctionExpr { PctChange, #[cfg(feature = "interpolate")] Interpolate(InterpolationMethod), + #[cfg(feature = "interpolate_by")] + InterpolateBy, #[cfg(feature = "log")] Entropy { base: f64, @@ -390,6 +392,8 @@ impl Hash for FunctionExpr { Diff(_, null_behavior) => null_behavior.hash(state), #[cfg(feature = "interpolate")] Interpolate(f) => f.hash(state), + #[cfg(feature = "interpolate_by")] + InterpolateBy => {}, #[cfg(feature = "ffi_plugin")] FfiPlugin { lib, @@ -676,6 +680,8 @@ impl Display for FunctionExpr { PctChange => "pct_change", #[cfg(feature = "interpolate")] Interpolate(_) => "interpolate", + #[cfg(feature = "interpolate_by")] + InterpolateBy => "interpolate_by", #[cfg(feature = "log")] Entropy { .. } => "entropy", #[cfg(feature = "log")] @@ -1009,6 +1015,10 @@ impl From for SpecialEq> { Interpolate(method) => { map!(dispatch::interpolate, method) }, + #[cfg(feature = "interpolate_by")] + InterpolateBy => { + map_as_slice!(dispatch::interpolate_by) + }, #[cfg(feature = "log")] Entropy { base, normalize } => map!(log::entropy, base, normalize), #[cfg(feature = "log")] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 49d5eb9ecc81..dadfa5560c65 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -172,6 +172,8 @@ impl FunctionExpr { InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), InterpolationMethod::Nearest => mapper.with_same_dtype(), }, + #[cfg(feature = "interpolate_by")] + InterpolateBy => mapper.map_numeric_to_float_dtype(), ShrinkType => { // we return the smallest type this can return // this might not be correct once the actual data diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 633e14938dcb..2492f887f6a1 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1253,6 +1253,12 @@ impl Expr { ) } + #[cfg(feature = "interpolate_by")] + /// Fill null values using interpolation. + pub fn interpolate_by(self, by: Expr) -> Expr { + self.apply_many_private(FunctionExpr::InterpolateBy, &[by], false, false) + } + #[cfg(feature = "rolling_window")] #[allow(clippy::type_complexity)] fn finish_rolling( diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 9afc4e6ca324..7f88b92ff9e9 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -164,6 +164,7 @@ extract_jsonpath = [ find_many = ["polars-plan/find_many"] fused = ["polars-ops/fused", "polars-lazy?/fused"] interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] +interpolate_by = ["polars-ops/interpolate_by", "polars-lazy?/interpolate_by"] is_between = ["polars-lazy?/is_between", "polars-ops/is_between"] is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"] is_in = ["polars-lazy?/is_in"] @@ -371,6 +372,7 @@ docs-selection = [ "rolling_window", "rolling_window_by", "interpolate", + "interpolate_by", "diff", "rank", "range", diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 1d7f43b8db36..569dea1d6f26 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -245,7 +245,7 @@ //! - `mode` - [Return the most occurring value(s)](polars_ops::chunked_array::mode) //! - `cum_agg` - [`cum_sum`], [`cum_min`], [`cum_max`] aggregation. //! - `rolling_window` - rolling window functions, like [`rolling_mean`] -//! - `interpolate` [interpolate None values](polars_ops::chunked_array::interpolate) +//! - `interpolate` [interpolate None values](polars_ops::series::interpolate()) //! - `extract_jsonpath` - [Run jsonpath queries on StringChunked](https://goessner.net/articles/JsonPath/) //! - `list` - List utils. //! - `list_gather` take sublist by multiple indices diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index e651e60116c8..a602b9691a5e 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -56,6 +56,7 @@ features = [ "ewma_by", "fmt", "interpolate", + "interpolate_by", "is_first_distinct", "is_last_distinct", "is_unique", diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 082a6e9e7e24..4505a19ad7f6 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6216,6 +6216,41 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self: """ return self._from_pyexpr(self._pyexpr.interpolate(method)) + def interpolate_by(self, by: IntoExpr) -> Self: + """ + Fill null values using interpolation based on another column. + + Parameters + ---------- + by + Column to interpolate values based on. + + Examples + -------- + Fill null values using linear interpolation. + + >>> df = pl.DataFrame( + ... { + ... "a": [1, None, None, 3], + ... "b": [1, 2, 7, 8], + ... } + ... ) + >>> df.with_columns(a_interpolated=pl.col("a").interpolate_by("b")) + shape: (4, 3) + ┌──────┬─────┬────────────────┐ + │ a ┆ b ┆ a_interpolated │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ f64 │ + ╞══════╪═════╪════════════════╡ + │ 1 ┆ 1 ┆ 1.0 │ + │ null ┆ 2 ┆ 1.285714 │ + │ null ┆ 7 ┆ 2.714286 │ + │ 3 ┆ 8 ┆ 3.0 │ + └──────┴─────┴────────────────┘ + """ + by = parse_as_expression(by) + return self._from_pyexpr(self._pyexpr.interpolate_by(by)) + @unstable() def rolling_min_by( self, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index d17a83ee21e3..0cc485504939 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6451,6 +6451,32 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Series: ] """ + def interpolate_by(self, by: IntoExpr) -> Series: + """ + Fill null values using interpolation based on another column. + + Parameters + ---------- + by + Column to interpolate values based on. + + Examples + -------- + Fill null values using linear interpolation. + + >>> s = pl.Series([1, None, None, 3]) + >>> by = pl.Series([1, 2, 7, 8]) + >>> s.interpolate_by(by) + shape: (4,) + Series: '' [f64] + [ + 1.0 + 1.285714 + 2.714286 + 3.0 + ] + """ + def abs(self) -> Series: """ Compute absolute values. @@ -7087,7 +7113,7 @@ def ewm_mean( def ewm_mean_by( self, - by: str | IntoExpr, + by: IntoExpr, *, half_life: str | timedelta, ) -> Series: diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 6bd1a706c9cb..5808a30d9a18 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -780,6 +780,9 @@ impl PyExpr { fn interpolate(&self, method: Wrap) -> Self { self.inner.clone().interpolate(method.0).into() } + fn interpolate_by(&self, by: PyExpr) -> Self { + self.inner.clone().interpolate_by(by.inner).into() + } fn lower_bound(&self) -> Self { self.inner.clone().lower_bound().into() diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index e4edc3323b1a..82b0e5281362 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -975,6 +975,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::Interpolate(_) => { return Err(PyNotImplementedError::new_err("interpolate")) }, + FunctionExpr::InterpolateBy => { + return Err(PyNotImplementedError::new_err("interpolate_by")) + }, FunctionExpr::Entropy { base: _, normalize: _, diff --git a/py-polars/tests/unit/operations/test_interpolate_by.py b/py-polars/tests/unit/operations/test_interpolate_by.py new file mode 100644 index 000000000000..39a293575a5a --- /dev/null +++ b/py-polars/tests/unit/operations/test_interpolate_by.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from datetime import date +from typing import TYPE_CHECKING + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import assume, given + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing.parametric import column, dataframes + +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType + + +@pytest.mark.parametrize( + "times_dtype", + [ + pl.Datetime("ms"), + pl.Datetime("us", "Asia/Kathmandu"), + pl.Datetime("ns"), + pl.Date, + pl.Int64, + pl.Int32, + pl.UInt64, + pl.UInt32, + ], +) +@pytest.mark.parametrize( + "values_dtype", + [ + pl.Float64, + pl.Float32, + pl.Int64, + pl.Int32, + pl.UInt64, + pl.UInt32, + ], +) +def test_interpolate_by( + values_dtype: PolarsDataType, times_dtype: PolarsDataType +) -> None: + df = pl.DataFrame( + { + "times": [ + 1, + 3, + 10, + 11, + 12, + 16, + 21, + 30, + ], + "values": [1, None, None, 5, None, None, None, 6], + }, + schema={"times": times_dtype, "values": values_dtype}, + ) + result = df.select(pl.col("values").interpolate_by("times")) + expected = pl.DataFrame( + { + "values": [ + 1.0, + 1.7999999999999998, + 4.6, + 5.0, + 5.052631578947368, + 5.2631578947368425, + 5.526315789473684, + 6.0, + ] + } + ) + if values_dtype == pl.Float32: + expected = expected.select(pl.col("values").cast(pl.Float32)) + assert_frame_equal(result, expected) + result = ( + df.sort("times", descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times") + .drop("times") + ) + assert_frame_equal(result, expected) + + +def test_interpolate_by_leading_nulls() -> None: + df = pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + ], + "values": [None, None, None, 1, None, None, 5], + } + ) + result = df.select(pl.col("values").interpolate_by("times")) + expected = pl.DataFrame( + {"values": [None, None, None, 1.0, 1.7999999999999998, 4.6, 5.0]} + ) + assert_frame_equal(result, expected) + result = ( + df.sort("times", descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times") + .drop("times") + ) + assert_frame_equal(result, expected) + + +def test_interpolate_by_trailing_nulls() -> None: + df = pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + date(2020, 1, 12), + date(2020, 1, 13), + ], + "values": [1, None, None, 5, None, None], + } + ) + result = df.select(pl.col("values").interpolate_by("times")) + expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}) + assert_frame_equal(result, expected) + result = ( + df.sort("times", descending=True) + .with_columns(pl.col("values").interpolate_by("times")) + .sort("times") + .drop("times") + ) + assert_frame_equal(result, expected) + + +@given(data=st.data()) +def test_interpolate_vs_numpy(data: st.DataObject) -> None: + dataframe = ( + data.draw( + dataframes( + [ + column( + "ts", + dtype=pl.Date, + allow_null=False, + ), + column( + "value", + dtype=pl.Float64, + allow_null=True, + ), + ], + min_size=1, + ) + ) + .sort("ts") + .fill_nan(None) + .unique("ts") + ) + assume(not dataframe["value"].is_null().all()) + assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any()) + result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"] + + mask = dataframe["value"].is_not_null() + x = dataframe["ts"].to_numpy().astype("int64") + xp = dataframe["ts"].filter(mask).to_numpy().astype("int64") + yp = dataframe["value"].filter(mask).to_numpy().astype("float64") + interp = np.interp(x, xp, yp) + # Polars preserves nulls on boundaries, but NumPy doesn't. + first_non_null = dataframe["value"].is_not_null().arg_max() + last_non_null = len(dataframe) - dataframe["value"][::-1].is_not_null().arg_max() # type: ignore[operator] + interp[:first_non_null] = float("nan") + interp[last_non_null:] = float("nan") + expected = dataframe.with_columns(value=pl.Series(interp, nan_to_null=True))[ + "value" + ] + + assert_series_equal(result, expected) + result_from_unsorted = ( + dataframe.sort("ts", descending=True) + .with_columns(pl.col("value").interpolate_by("ts")) + .sort("ts")["value"] + ) + assert_series_equal(result_from_unsorted, expected) + + +def test_interpolate_by_invalid() -> None: + s = pl.Series([1, None, 3]) + by = pl.Series([1, 2]) + with pytest.raises(pl.InvalidOperationError, match=r"\(3\), got 2"): + s.interpolate_by(by) + + by = pl.Series([1, None, 3]) + with pytest.raises( + pl.InvalidOperationError, + match="null values in `by` column are not yet supported in 'interpolate_by'", + ): + s.interpolate_by(by)