From 2e0064721a299ed8bad9e89ba2adc4912c792a65 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 13 May 2024 13:10:30 +0200 Subject: [PATCH] feat(python): Support `Enum` types in parametric testing (#16188) --- .../polars/testing/parametric/strategies/data.py | 16 +++++++++++++++- .../testing/parametric/strategies/dtype.py | 9 +++++++++ .../testing/parametric/strategies/test_core.py | 7 +++++++ .../testing/parametric/strategies/test_data.py | 10 ++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 78dc119c8929..7cfda456c390 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -33,6 +33,7 @@ Datetime, Decimal, Duration, + Enum, Float32, Float64, Int8, @@ -49,7 +50,10 @@ UInt64, ) from polars.testing.parametric.strategies._utils import flexhash -from polars.testing.parametric.strategies.dtype import _DEFAULT_ARRAY_WIDTH_LIMIT +from polars.testing.parametric.strategies.dtype import ( + _DEFAULT_ARRAY_WIDTH_LIMIT, + _DEFAULT_ENUM_CATEGORIES_LIMIT, +) if TYPE_CHECKING: from datetime import date, datetime, time @@ -329,6 +333,16 @@ def data( strategy = categories( n_categories=kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES) ) + elif dtype == Enum: + if isinstance(dtype, Enum): + if (cats := dtype.categories).is_empty(): + strategy = nulls() + else: + strategy = st.sampled_from(cats.to_list()) + else: + strategy = categories( + n_categories=kwargs.pop("n_categories", _DEFAULT_ENUM_CATEGORIES_LIMIT) + ) elif dtype == Decimal: strategy = decimals( getattr(dtype, "precision", None), getattr(dtype, "scale", 0) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index ae54a5e405b4..91bae2317b93 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -15,6 +15,7 @@ Datetime, Decimal, Duration, + Enum, Float32, Float64, Int8, @@ -63,6 +64,7 @@ Duration, Categorical, Decimal, + Enum, ] # Supported data type classes that contain other data types _NESTED_DTYPES: list[DataTypeClass] = [ @@ -76,6 +78,7 @@ _DEFAULT_ARRAY_WIDTH_LIMIT = 3 _DEFAULT_STRUCT_FIELDS_LIMIT = 3 +_DEFAULT_ENUM_CATEGORIES_LIMIT = 3 def dtypes( @@ -174,6 +177,12 @@ def _instantiate_flat_dtype(draw: DrawFn, dtype: PolarsDataType) -> DataType: elif dtype == Categorical: ordering = draw(_categorical_orderings()) return Categorical(ordering) + elif dtype == Enum: + n_categories = draw( + st.integers(min_value=1, max_value=_DEFAULT_ENUM_CATEGORIES_LIMIT) + ) + categories = [f"c{i}" for i in range(n_categories)] + return Enum(categories) elif dtype == Decimal: precision = draw(st.integers(min_value=1, max_value=38) | st.none()) scale = draw(st.integers(min_value=0, max_value=precision or 38)) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index f32807f52452..13a7c31cb291 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -39,6 +39,13 @@ def test_series_dtype(data: st.DataObject) -> None: assert s.dtype == dtype +@given(s=series(dtype=pl.Enum)) +@settings(max_examples=5) +def test_series_dtype_enum(s: pl.Series) -> None: + assert isinstance(s.dtype, pl.Enum) + assert all(v in s.dtype.categories for v in s) + + @given(s=series(dtype=pl.Boolean, size=5)) @settings(max_examples=5) def test_series_size(s: pl.Series) -> None: diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py index 0820015158dc..a7316b0a7a7b 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_data.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -19,3 +19,13 @@ def test_data_kwargs(cat: str) -> None: @given(categories=data(pl.List(pl.Categorical), n_categories=3)) def test_data_nested_kwargs(categories: list[str]) -> None: assert all(c in ("c0", "c1", "c2") for c in categories) + + +@given(cat=data(pl.Enum)) +def test_data_enum(cat: str) -> None: + assert cat in ("c0", "c1", "c2") + + +@given(cat=data(pl.Enum(["hello", "world"]))) +def test_data_enum_instantiated(cat: str) -> None: + assert cat in ("hello", "world")