diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index 11e1086afbc7d..237af089a4e9d 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -1,11 +1,11 @@ from __future__ import annotations -import random +import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Collection, Sequence, overload import hypothesis.strategies as st -from hypothesis.errors import InvalidArgument +from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning import polars.functions as F from polars.dataframe import DataFrame @@ -376,11 +376,11 @@ def dataframes( # noqa: D417 cols = [column() for _ in range(cols)] elif isinstance(cols, column): cols = [cols] - elif not isinstance(cols, list): + else: cols = list(cols) if include_cols: - cols.extend(include_cols) + cols.extend(list(include_cols)) if size is None: size = draw(st.integers(min_value=min_size, max_value=max_size)) @@ -534,7 +534,9 @@ def columns( """ # create/assign named columns if cols is None: - cols = random.randint(min_cols, max_cols) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning) + cols = st.integers(min_value=min_cols, max_value=max_cols).example() if isinstance(cols, int): names: Sequence[str] = [f"col{n}" for n in range(cols)] else: diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py index c2d8c0163add8..16732a07bf7d1 100644 --- a/py-polars/polars/testing/parametric/strategies.py +++ b/py-polars/polars/testing/parametric/strategies.py @@ -79,7 +79,7 @@ def dtypes( allowed_dtypes: Sequence[PolarsDataType] | None = None, excluded_dtypes: Collection[PolarsDataType] | None = None, ) -> DataType: - """Returns a strategy which generates Polars data types.""" + """Returns a strategy which generates Polars `DataType` objects.""" if allowed_dtypes is None: allowed_dtypes = STRATEGY_DTYPES if excluded_dtypes: diff --git a/py-polars/tests/parametric/test_testing.py b/py-polars/tests/parametric/test_testing.py index 7297fe5165324..dea08008783ec 100644 --- a/py-polars/tests/parametric/test_testing.py +++ b/py-polars/tests/parametric/test_testing.py @@ -7,10 +7,10 @@ from datetime import datetime from typing import Any +import hypothesis.strategies as st import pytest from hypothesis import given, settings from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning -from hypothesis.strategies import sampled_from import polars as pl from polars.datatypes import TEMPORAL_DTYPES @@ -21,6 +21,7 @@ dataframes, series, ) +from polars.testing.parametric.primitives import MAX_COLS # TODO: add parametric strategy generator that supports timezones TEMPORAL_DTYPES_ = { @@ -59,8 +60,9 @@ def test_strategy_shape( assert s1.name == "" assert s2.name == "col" - from polars.testing.parametric.primitives import MAX_COLS +@pytest.mark.hypothesis() +def test_columns_auto_infer() -> None: assert 0 <= len(columns(None)) <= MAX_COLS @@ -73,7 +75,7 @@ def test_strategy_shape( cols=columns(["a", "b"], dtype=pl.UInt8, unique=True), include_cols=[ column("c", dtype=pl.Boolean), - column("d", strategy=sampled_from(["x", "y", "z"])), + column("d", strategy=st.sampled_from(["x", "y", "z"])), ], ) ) @@ -256,7 +258,7 @@ def test_invalid_arguments() -> None: column("colx", dtype=pl.Struct) with pytest.raises(InvalidArgument, match="unable to determine dtype"): - column("colx", strategy=sampled_from([None])) + column("colx", strategy=st.none()) with pytest.raises(InvalidArgument, match="not a valid polars datatype"): columns(["colx", "coly"], dtype=pl.DataFrame) # type: ignore[arg-type]