From b91f064cf391e93ec9c619b6bc9f0b7dd1c14272 Mon Sep 17 00:00:00 2001 From: thalassemia <67928790+thalassemia@users.noreply.github.com> Date: Wed, 8 May 2024 17:33:39 -0700 Subject: [PATCH] feat(rust): Better testing of hybrid RLE encoder --- .github/workflows/test-rust.yml | 2 + py-polars/tests/unit/io/test_parquet.py | 66 ++++++++++++++++--------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 6db364e210d15..4e54ca0cf8e99 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -52,6 +52,7 @@ jobs: -p polars-io -p polars-lazy -p polars-ops + -p polars-parquet -p polars-plan -p polars-row -p polars-sql @@ -68,6 +69,7 @@ jobs: -p polars-io -p polars-lazy -p polars-ops + -p polars-parquet -p polars-plan -p polars-row -p polars-sql diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 846b4252e5483..dcd2f15a28916 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -896,36 +896,58 @@ def test_no_glob_windows(tmp_path: Path) -> None: @pytest.mark.slow() def test_hybrid_rle() -> None: + # 10_007 elements to test if not a nice multiple of 8 + n = 10_007 + literal_literal = [] + literal_rle = [] + for i in range(500): + literal_literal.append(np.repeat(i, 5)) + literal_literal.append(np.repeat(i + 2, 15)) + literal_rle.append(np.repeat(i, 5)) + literal_rle.append(np.repeat(i + 2, 11)) + literal_literal.append(np.random.randint(0, 10, size=7)) + literal_rle.append(np.random.randint(0, 10, size=2007)) + literal_literal = np.concatenate(literal_literal) + literal_rle = np.concatenate(literal_rle) df = pl.DataFrame( { - # Test primitive types - "i64": pl.repeat(int(2**63 - 1), n=10000, dtype=pl.Int64, eager=True), - "u64": pl.repeat(int(2**64 - 1), n=10000, dtype=pl.UInt64, eager=True), - "i8": pl.repeat(-int(2**7 - 1), n=10000, dtype=pl.Int8, eager=True), - "u8": pl.repeat(int(2**8 - 1), n=10000, dtype=pl.UInt8, eager=True), - "string": pl.repeat("a", n=10000, dtype=pl.String, eager=True), - "categorical": pl.Series((["a"] * 9 + ["b"]) * 1000, dtype=pl.Categorical), - # Test filling up bit-packing buffer - "large_bit_pack": ([0] * 5 + [1] * 5) * 1000, - # Test mix of bit-packed and RLE runs - "bit_pack_and_rle": ( - [0] + [1] * 19 + [2] * 8 + [3] * 12 + [4] * 5 + [5] * 5 - ) - * 200, - # Test some null values - "nulls_included": ( - [None] + [1] * 19 + [None] * 8 + [3] * 12 + [4] * 5 + [None] * 5 - ) - * 200, + # Primitive types + "i64": pl.Series([1, 2], dtype=pl.Int64).sample(n, with_replacement=True), + "u64": pl.Series([1, 2], dtype=pl.UInt64).sample(n, with_replacement=True), + "i8": pl.Series([1, 2], dtype=pl.Int8).sample(n, with_replacement=True), + "u8": pl.Series([1, 2], dtype=pl.UInt8).sample(n, with_replacement=True), + "string": pl.Series(["abc", "def"], dtype=pl.String).sample( + n, with_replacement=True + ), + "categorical": pl.Series(["aaa", "bbb"], dtype=pl.Categorical).sample( + n, with_replacement=True + ), + # Fill up bit-packing buffer in middle of consecutive run + "large_bit_pack": np.concatenate( + [np.repeat(i, 5) for i in range(2000)] + + [np.random.randint(0, 10, size=7)] + ), + # Literal run that is not a multiple of 8 followed by consecutive + # run initially long enough to RLE but not after padding literal + "literal_literal": literal_literal, + # Literal run that is not a multiple of 8 followed by consecutive + # run long enough to RLE even after padding literal + "literal_rle": literal_rle, + # Final run not long enough to RLE + "final_literal": np.concatenate( + [np.random.randint(0, 100, 10_000), np.repeat(-1, 7)] + ), + # Final run long enough to RLE + "final_rle": np.concatenate( + [np.random.randint(0, 100, 9_998), np.repeat(-1, 9)] + ), # Test filling up bit-packing buffer for encode_bool, # which is only used to encode validities - # Also checks that runs are handled correctly if buffer - # is flushed (at MAX_VALUES_PER_LITERAL_RUN values) "large_bit_pack_validity": [0, None] * 4092 + [0] * 9 + [1] * 9 + [2] * 10 - + [0] * 1788, + + [0] * 1795, } ) f = io.BytesIO()