Skip to content

Commit

Permalink
Fix use of columns
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 5, 2024
1 parent 00066fe commit c14760a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
12 changes: 7 additions & 5 deletions py-polars/polars/testing/parametric/primitives.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import random
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Collection, Sequence, overload

import hypothesis.strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning

import polars.functions as F
from polars.dataframe import DataFrame
Expand Down Expand Up @@ -376,11 +376,11 @@ def dataframes( # noqa: D417
cols = [column() for _ in range(cols)]
elif isinstance(cols, column):
cols = [cols]
elif not isinstance(cols, list):
else:
cols = list(cols)

if include_cols:
cols.extend(include_cols)
cols.extend(list(include_cols))

if size is None:
size = draw(st.integers(min_value=min_size, max_value=max_size))
Expand Down Expand Up @@ -534,7 +534,9 @@ def columns(
"""
# create/assign named columns
if cols is None:
cols = random.randint(min_cols, max_cols)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NonInteractiveExampleWarning)
cols = st.integers(min_value=min_cols, max_value=max_cols).example()
if isinstance(cols, int):
names: Sequence[str] = [f"col{n}" for n in range(cols)]
else:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/testing/parametric/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def dtypes(
allowed_dtypes: Sequence[PolarsDataType] | None = None,
excluded_dtypes: Collection[PolarsDataType] | None = None,
) -> DataType:
"""Returns a strategy which generates Polars data types."""
"""Returns a strategy which generates Polars `DataType` objects."""
if allowed_dtypes is None:
allowed_dtypes = STRATEGY_DTYPES
if excluded_dtypes:
Expand Down
10 changes: 6 additions & 4 deletions py-polars/tests/parametric/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from datetime import datetime
from typing import Any

import hypothesis.strategies as st
import pytest
from hypothesis import given, settings
from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning
from hypothesis.strategies import sampled_from

import polars as pl
from polars.datatypes import TEMPORAL_DTYPES
Expand All @@ -21,6 +21,7 @@
dataframes,
series,
)
from polars.testing.parametric.primitives import MAX_COLS

# TODO: add parametric strategy generator that supports timezones
TEMPORAL_DTYPES_ = {
Expand Down Expand Up @@ -59,8 +60,9 @@ def test_strategy_shape(
assert s1.name == ""
assert s2.name == "col"

from polars.testing.parametric.primitives import MAX_COLS

@pytest.mark.hypothesis()
def test_columns_auto_infer() -> None:
assert 0 <= len(columns(None)) <= MAX_COLS


Expand All @@ -73,7 +75,7 @@ def test_strategy_shape(
cols=columns(["a", "b"], dtype=pl.UInt8, unique=True),
include_cols=[
column("c", dtype=pl.Boolean),
column("d", strategy=sampled_from(["x", "y", "z"])),
column("d", strategy=st.sampled_from(["x", "y", "z"])),
],
)
)
Expand Down Expand Up @@ -256,7 +258,7 @@ def test_invalid_arguments() -> None:
column("colx", dtype=pl.Struct)

with pytest.raises(InvalidArgument, match="unable to determine dtype"):
column("colx", strategy=sampled_from([None]))
column("colx", strategy=st.none())

with pytest.raises(InvalidArgument, match="not a valid polars datatype"):
columns(["colx", "coly"], dtype=pl.DataFrame) # type: ignore[arg-type]
Expand Down

0 comments on commit c14760a

Please sign in to comment.