Skip to content

Commit

Permalink
feat(python): Support Enum types in parametric testing (#16188)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 13, 2024
1 parent 54ddfa1 commit 2e00647
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
16 changes: 15 additions & 1 deletion py-polars/polars/testing/parametric/strategies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Datetime,
Decimal,
Duration,
Enum,
Float32,
Float64,
Int8,
Expand All @@ -49,7 +50,10 @@
UInt64,
)
from polars.testing.parametric.strategies._utils import flexhash
from polars.testing.parametric.strategies.dtype import _DEFAULT_ARRAY_WIDTH_LIMIT
from polars.testing.parametric.strategies.dtype import (
_DEFAULT_ARRAY_WIDTH_LIMIT,
_DEFAULT_ENUM_CATEGORIES_LIMIT,
)

if TYPE_CHECKING:
from datetime import date, datetime, time
Expand Down Expand Up @@ -329,6 +333,16 @@ def data(
strategy = categories(
n_categories=kwargs.pop("n_categories", _DEFAULT_N_CATEGORIES)
)
elif dtype == Enum:
if isinstance(dtype, Enum):
if (cats := dtype.categories).is_empty():
strategy = nulls()
else:
strategy = st.sampled_from(cats.to_list())
else:
strategy = categories(
n_categories=kwargs.pop("n_categories", _DEFAULT_ENUM_CATEGORIES_LIMIT)
)
elif dtype == Decimal:
strategy = decimals(
getattr(dtype, "precision", None), getattr(dtype, "scale", 0)
Expand Down
9 changes: 9 additions & 0 deletions py-polars/polars/testing/parametric/strategies/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Datetime,
Decimal,
Duration,
Enum,
Float32,
Float64,
Int8,
Expand Down Expand Up @@ -63,6 +64,7 @@
Duration,
Categorical,
Decimal,
Enum,
]
# Supported data type classes that contain other data types
_NESTED_DTYPES: list[DataTypeClass] = [
Expand All @@ -76,6 +78,7 @@

_DEFAULT_ARRAY_WIDTH_LIMIT = 3
_DEFAULT_STRUCT_FIELDS_LIMIT = 3
_DEFAULT_ENUM_CATEGORIES_LIMIT = 3


def dtypes(
Expand Down Expand Up @@ -174,6 +177,12 @@ def _instantiate_flat_dtype(draw: DrawFn, dtype: PolarsDataType) -> DataType:
elif dtype == Categorical:
ordering = draw(_categorical_orderings())
return Categorical(ordering)
elif dtype == Enum:
n_categories = draw(
st.integers(min_value=1, max_value=_DEFAULT_ENUM_CATEGORIES_LIMIT)
)
categories = [f"c{i}" for i in range(n_categories)]
return Enum(categories)
elif dtype == Decimal:
precision = draw(st.integers(min_value=1, max_value=38) | st.none())
scale = draw(st.integers(min_value=0, max_value=precision or 38))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def test_series_dtype(data: st.DataObject) -> None:
assert s.dtype == dtype


@given(s=series(dtype=pl.Enum))
@settings(max_examples=5)
def test_series_dtype_enum(s: pl.Series) -> None:
assert isinstance(s.dtype, pl.Enum)
assert all(v in s.dtype.categories for v in s)


@given(s=series(dtype=pl.Boolean, size=5))
@settings(max_examples=5)
def test_series_size(s: pl.Series) -> None:
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/testing/parametric/strategies/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ def test_data_kwargs(cat: str) -> None:
@given(categories=data(pl.List(pl.Categorical), n_categories=3))
def test_data_nested_kwargs(categories: list[str]) -> None:
assert all(c in ("c0", "c1", "c2") for c in categories)


@given(cat=data(pl.Enum))
def test_data_enum(cat: str) -> None:
assert cat in ("c0", "c1", "c2")


@given(cat=data(pl.Enum(["hello", "world"])))
def test_data_enum_instantiated(cat: str) -> None:
assert cat in ("hello", "world")

0 comments on commit 2e00647

Please sign in to comment.