From 7fa8a3ebe2ebc1009d877000cb40d7afd8821997 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 3 Apr 2024 17:24:47 +0800 Subject: [PATCH] fix(python): Raise if pass a negative `n` into `clear` (#15432) Co-authored-by: Stijn de Gooijer --- py-polars/polars/dataframe/frame.py | 17 +++++++++-------- py-polars/polars/series/series.py | 4 ++++ py-polars/tests/unit/operations/test_clear.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index a34c4aae5b84..8b9caf5a3a85 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -6912,17 +6912,18 @@ def clear(self, n: int = 0) -> Self: │ null ┆ null ┆ null │ └──────┴──────┴──────┘ """ + if n < 0: + msg = f"`n` should be greater than or equal to 0, got {n}" + raise ValueError(msg) # faster path if n == 0: return self._from_pydf(self._df.clear()) - if n > 0 or len(self) > 0: - return self.__class__( - { - nm: pl.Series(name=nm, dtype=tp).extend_constant(None, n) - for nm, tp in self.schema.items() - } - ) - return self.clone() + return self.__class__( + { + nm: pl.Series(name=nm, dtype=tp).extend_constant(None, n) + for nm, tp in self.schema.items() + } + ) def clone(self) -> Self: """ diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a651f8f77508..74834ecde8c8 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4759,6 +4759,10 @@ def clear(self, n: int = 0) -> Series: null ] """ + if n < 0: + msg = f"`n` should be greater than or equal to 0, got {n}" + raise ValueError(msg) + # faster path if n == 0: return self._from_pyseries(self._s.clear()) s = ( diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index 0ac3c1d27ba0..c9a2d29c1492 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -73,3 +73,13 @@ def test_clear_series_object_starting_with_null() -> None: assert result.dtype == s.dtype assert result.name == s.name assert result.is_empty() + + +def test_clear_raise_negative_n() -> None: + s = pl.Series([1, 2, 3]) + + msg = "`n` should be greater than or equal to 0, got -1" + with pytest.raises(ValueError, match=msg): + s.clear(-1) + with pytest.raises(ValueError, match=msg): + s.to_frame().clear(-1)