From e76170a9739fc3f9d72c263d69499c6394c3174f Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 24 May 2024 18:25:39 +0200 Subject: [PATCH] fix bug with nulls at start/end --- .../src/chunked_array/ops/fill_null.rs | 8 +++++--- py-polars/tests/unit/dataframe/test_df.py | 18 ++++++++++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 88dafe4edd58..f09852f94f5e 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -299,7 +299,7 @@ where .collect_trusted(); // Compute bitmask. - let num_start_nulls = ca.first_non_null().unwrap_or(0); + let num_start_nulls = ca.first_non_null().unwrap_or(ca.len()); let mut bm = MutableBitmap::with_capacity(ca.len()); bm.extend_constant(num_start_nulls, false); bm.extend_constant(ca.len() - num_start_nulls, true); @@ -330,8 +330,10 @@ where .collect_reversed(); // Compute bitmask. - let last_idx = ca.len().saturating_sub(1); - let num_end_nulls = last_idx - ca.last_non_null().unwrap_or(last_idx); + let num_end_nulls = ca + .last_non_null() + .map(|i| ca.len() - 1 - i) + .unwrap_or(ca.len()); let mut bm = MutableBitmap::with_capacity(ca.len()); bm.extend_constant(ca.len() - num_end_nulls, true); bm.extend_constant(num_end_nulls, false); diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 02377153c863..ce36e51a45e2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1866,10 +1866,24 @@ def test_fill_nan() -> None: assert df.fill_nan(2.0).dtypes == [pl.Float64, pl.Datetime] +def test_forward_fill() -> None: + df = pl.DataFrame({"a": [1.0, None, 3.0]}) + fill = df.select(pl.col("a").forward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 1, 3]).cast(pl.Float64)) + + df = pl.DataFrame({"a": [None, 1, None]}) + fill = df.select(pl.col("a").forward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [None, 1, 1]).cast(pl.Int64)) + + def test_backward_fill() -> None: df = pl.DataFrame({"a": [1.0, None, 3.0]}) - col_a_backward_fill = df.select([pl.col("a").backward_fill()])["a"] - assert_series_equal(col_a_backward_fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64)) + fill = df.select(pl.col("a").backward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 3, 3]).cast(pl.Float64)) + + df = pl.DataFrame({"a": [None, 1, None]}) + fill = df.select(pl.col("a").backward_fill())["a"] + assert_series_equal(fill, pl.Series("a", [1, 1, None]).cast(pl.Int64)) def test_shrink_to_fit() -> None: