From 2774bbbf0b59fb6b2def41fb1099848fca9af7df Mon Sep 17 00:00:00 2001 From: Petros Barbagiannis Date: Tue, 26 Mar 2024 16:21:00 +0200 Subject: [PATCH] fix(python): Handle special case correctly when slicing a `LazyFrame` (#15297) --- py-polars/polars/slice.py | 10 ++++++++-- py-polars/tests/unit/operations/test_slice.py | 3 +++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/slice.py b/py-polars/polars/slice.py index 5ee586cf5e5e..b2a9f370048e 100644 --- a/py-polars/polars/slice.py +++ b/py-polars/polars/slice.py @@ -158,7 +158,7 @@ def apply(self, s: slice) -> LazyFrame: # [::k] => gather_every(k), # [::-1] => reverse(), # [::-k] => reverse().gather_every(abs(k)) - elif start == 0 and s.stop is None: + elif s.start is None and s.stop is None: if step == 1: return self.obj.clone() elif step > 1: @@ -168,7 +168,13 @@ def apply(self, s: slice) -> LazyFrame: elif step < -1: return self.obj.reverse().gather_every(abs(step)) - elif start > 0 > step and s.stop is None: + # --------------------------------------- + # straight-through mappings for "head", + # "reverse" and "gather_every" + # --------------------------------------- + # [i::-1] => head(i+1).reverse() + # [i::k], k<-1 => head(i+1).reverse().gather_every(abs(k)) + elif start >= 0 > step and s.stop is None: obj = self.obj.head(s.start + 1).reverse() return obj if (abs(step) == 1) else obj.gather_every(abs(step)) diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py index b9f7152b87d4..1e5f652c7882 100644 --- a/py-polars/tests/unit/operations/test_slice.py +++ b/py-polars/tests/unit/operations/test_slice.py @@ -88,10 +88,13 @@ def test_python_slicing_lazy_frame() -> None: slice(None, 2, 2), slice(3, None, -1), slice(1, None, -2), + slice(0, None, -1), ): # confirm frame slice matches python slice assert ldf[py_slice].collect().rows() == ldf.collect().rows()[py_slice] + assert_frame_equal(ldf[0::-1], ldf.head(1)) + assert_frame_equal(ldf[2::-1], ldf.head(3).reverse()) assert_frame_equal(ldf[::-1], ldf.reverse()) assert_frame_equal(ldf[::-2], ldf.reverse().gather_every(2))