Skip to content

Commit

Permalink
perf: use is_sorted in ewm_mean_by, deprecate check_sorted (pola-rs#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored May 20, 2024
1 parent 7cf55e0 commit 9e3d614
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 61 deletions.
64 changes: 45 additions & 19 deletions crates/polars-ops/src/series/ops/ewm_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,68 @@ pub fn ewm_mean_by(
s: &Series,
times: &Series,
half_life: i64,
assume_sorted: bool,
times_is_sorted: bool,
) -> PolarsResult<Series> {
match (s.dtype(), times.dtype()) {
(DataType::Float64, DataType::Int64) => Ok((if assume_sorted {
ewm_mean_by_impl_sorted(s.f64().unwrap(), times.i64().unwrap(), half_life)
} else {
ewm_mean_by_impl(s.f64().unwrap(), times.i64().unwrap(), half_life)
})
.into_series()),
(DataType::Float32, DataType::Int64) => Ok((if assume_sorted {
ewm_mean_by_impl_sorted(s.f32().unwrap(), times.i64().unwrap(), half_life)
fn func<T>(
values: &ChunkedArray<T>,
times: &Int64Chunked,
half_life: i64,
times_is_sorted: bool,
) -> PolarsResult<Series>
where
T: PolarsFloatType,
T::Native: Float + Zero + One,
ChunkedArray<T>: IntoSeries,
{
if times_is_sorted {
Ok(ewm_mean_by_impl_sorted(values, times, half_life).into_series())
} else {
ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life)
})
.into_series()),
Ok(ewm_mean_by_impl(values, times, half_life).into_series())
}
}

match (s.dtype(), times.dtype()) {
(DataType::Float64, DataType::Int64) => func(
s.f64().unwrap(),
times.i64().unwrap(),
half_life,
times_is_sorted,
),
(DataType::Float32, DataType::Int64) => func(
s.f32().unwrap(),
times.i64().unwrap(),
half_life,
times_is_sorted,
),
#[cfg(feature = "dtype-datetime")]
(_, DataType::Datetime(time_unit, _)) => {
let half_life = adjust_half_life_to_time_unit(half_life, time_unit);
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
ewm_mean_by(
s,
&times.cast(&DataType::Int64)?,
half_life,
times_is_sorted,
)
},
#[cfg(feature = "dtype-date")]
(_, DataType::Date) => ewm_mean_by(
s,
&times.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
half_life,
assume_sorted,
times_is_sorted,
),
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => ewm_mean_by(
s,
&times.cast(&DataType::Int64)?,
half_life,
times_is_sorted,
),
(_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => {
ewm_mean_by(s, &times.cast(&DataType::Int64)?, half_life, assume_sorted)
},
(DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => {
ewm_mean_by(
&s.cast(&DataType::Float64)?,
times,
half_life,
assume_sorted,
times_is_sorted,
)
},
_ => {
Expand Down
12 changes: 5 additions & 7 deletions crates/polars-plan/src/dsl/function_expr/ewm_by.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use polars_ops::series::SeriesMethods;

use super::*;

pub(super) fn ewm_mean_by(
s: &[Series],
half_life: Duration,
check_sorted: bool,
) -> PolarsResult<Series> {
pub(super) fn ewm_mean_by(s: &[Series], half_life: Duration) -> PolarsResult<Series> {
let time_zone = match s[1].dtype() {
DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()),
_ => None,
Expand All @@ -15,6 +13,6 @@ pub(super) fn ewm_mean_by(
let half_life = half_life.duration_ns();
let values = &s[0];
let times = &s[1];
let assume_sorted = !check_sorted || times.is_sorted_flag() == IsSorted::Ascending;
polars_ops::prelude::ewm_mean_by(values, times, half_life, assume_sorted)
let times_is_sorted = times.is_sorted(Default::default())?;
polars_ops::prelude::ewm_mean_by(values, times, half_life, times_is_sorted)
}
11 changes: 2 additions & 9 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ pub enum FunctionExpr {
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life: Duration,
check_sorted: bool,
},
#[cfg(feature = "ewma")]
EwmStd {
Expand Down Expand Up @@ -542,10 +541,7 @@ impl Hash for FunctionExpr {
#[cfg(feature = "ewma")]
EwmMean { options } => options.hash(state),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
check_sorted,
} => (half_life, check_sorted).hash(state),
EwmMeanBy { half_life } => (half_life).hash(state),
#[cfg(feature = "ewma")]
EwmStd { options } => options.hash(state),
#[cfg(feature = "ewma")]
Expand Down Expand Up @@ -1118,10 +1114,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "ewma")]
EwmMean { options } => map!(ewm::ewm_mean, options),
#[cfg(feature = "ewma_by")]
EwmMeanBy {
half_life,
check_sorted,
} => map_as_slice!(ewm_by::ewm_mean_by, half_life, check_sorted),
EwmMeanBy { half_life } => map_as_slice!(ewm_by::ewm_mean_by, half_life),
#[cfg(feature = "ewma")]
EwmStd { options } => map!(ewm::ewm_std, options),
#[cfg(feature = "ewma")]
Expand Down
7 changes: 2 additions & 5 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1647,12 +1647,9 @@ impl Expr {

#[cfg(feature = "ewma_by")]
/// Calculate the exponentially-weighted moving average by a time column.
pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self {
pub fn ewm_mean_by(self, times: Expr, half_life: Duration) -> Self {
self.apply_many_private(
FunctionExpr::EwmMeanBy {
half_life,
check_sorted,
},
FunctionExpr::EwmMeanBy { half_life },
&[times],
false,
false,
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10537,7 +10537,7 @@ def ewm_mean_by(
by: str | IntoExpr,
*,
half_life: str | timedelta,
check_sorted: bool = True,
check_sorted: bool | None = None,
) -> Self:
r"""
Calculate time-based exponentially weighted moving average.
Expand Down Expand Up @@ -10587,6 +10587,10 @@ def ewm_mean_by(
Check whether `by` column is sorted.
Incorrectly setting this to `False` will lead to incorrect output.
.. deprecated:: 0.20.27
Sortedness is now verified in a quick manner, you can safely remove
this argument.
Returns
-------
Expr
Expand Down Expand Up @@ -10625,7 +10629,12 @@ def ewm_mean_by(
"""
by = parse_as_expression(by)
half_life = parse_as_duration_string(half_life)
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life, check_sorted))
if check_sorted is not None:
issue_deprecation_warning(
"`check_sorted` is now deprecated in `ewm_mean_by`, you can safely remove this argument.",
version="0.20.27",
)
return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life))

@deprecate_nonkeyword_arguments(version="0.19.10")
def ewm_std(
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,11 +858,11 @@ impl PyExpr {
};
self.inner.clone().ewm_mean(options).into()
}
fn ewm_mean_by(&self, times: PyExpr, half_life: &str, check_sorted: bool) -> Self {
fn ewm_mean_by(&self, times: PyExpr, half_life: &str) -> Self {
let half_life = Duration::parse(half_life);
self.inner
.clone()
.ewm_mean_by(times.inner, half_life, check_sorted)
.ewm_mean_by(times.inner, half_life)
.into()
}

Expand Down
7 changes: 3 additions & 4 deletions py-polars/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,10 +1020,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::TopKBy { sort_options: _ } => {
return Err(PyNotImplementedError::new_err("top_k_by"))
},
FunctionExpr::EwmMeanBy {
half_life: _,
check_sorted: _,
} => return Err(PyNotImplementedError::new_err("ewm_mean_by")),
FunctionExpr::EwmMeanBy { half_life: _ } => {
return Err(PyNotImplementedError::new_err("ewm_mean_by"))
},
},
options: py.None(),
}
Expand Down
4 changes: 1 addition & 3 deletions py-polars/tests/unit/functions/test_ewm_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ def test_ewm_by(data: st.DataObject, half_life: int) -> None:
)
)
result = df.with_row_index().select(
pl.col("values").ewm_mean_by(
by="index", half_life=f"{half_life}i", check_sorted=False
)
pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i")
)
expected = df.select(
pl.col("values").ewm_mean(half_life=half_life, ignore_nulls=False, adjust=False)
Expand Down
18 changes: 8 additions & 10 deletions py-polars/tests/unit/operations/test_ewm_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,20 @@ def test_ewma_by_empty() -> None:
assert_frame_equal(result, expected)


def test_ewma_by_warn_if_unsorted() -> None:
def test_ewma_by_if_unsorted() -> None:
df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]})

# Check that with `check_sorted=False`, the user can get incorrect results
# if they really want to.
result = df.select(
pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False),
)
expected = pl.DataFrame({"values": [3.0, 4.0]})
assert_frame_equal(result, expected)

result = df.with_columns(
pl.col("values").ewm_mean_by("by", half_life="2i"),
)
expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]})
assert_frame_equal(result, expected)

with pytest.deprecated_call(match="you can safely remove this argument"):
result = df.with_columns(
pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False),
)
assert_frame_equal(result, expected)

result = df.sort("by").with_columns(
pl.col("values").ewm_mean_by("by", half_life="2i"),
)
Expand Down

0 comments on commit 9e3d614

Please sign in to comment.