Skip to content

Commit

Permalink
fix(python): Fix some issues in parametric testing with nested dtypes (
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 14, 2024
1 parent ae4e71b commit f305c8c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 47 deletions.
29 changes: 15 additions & 14 deletions py-polars/polars/testing/parametric/strategies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def lists(
inner_dtype: DataType,
*,
select_from: Sequence[Any] | None = None,
min_len: int = 0,
max_len: int | None = None,
min_size: int = 0,
max_size: int | None = None,
unique: bool = False,
**kwargs: Any,
) -> SearchStrategy[list[Any]]:
Expand All @@ -242,12 +242,12 @@ def lists(
select_from
The values to use for the innermost lists. If set to `None` (default),
the default strategy associated with the innermost data type is used.
min_len
min_size
The minimum length of the generated lists.
max_len
max_size
The maximum length of the generated lists. If set to `None` (default), the
maximum is set based on `min_size`: `3` if `min_len` is zero,
otherwise `2 * min_len`.
maximum is set based on `min_size`: `3` if `min_size` is zero,
otherwise `2 * min_size`.
unique
Ensure that the generated lists contain unique values.
**kwargs
Expand All @@ -257,25 +257,25 @@ def lists(
--------
...
"""
if max_len is None:
max_len = _DEFAULT_LIST_LEN_LIMIT if min_len == 0 else min_len * 2
if max_size is None:
max_size = _DEFAULT_LIST_LEN_LIMIT if min_size == 0 else min_size * 2

if select_from is not None and not inner_dtype.is_nested():
inner_strategy = st.sampled_from(select_from)
else:
inner_strategy = data(
inner_dtype,
select_from=select_from,
min_size=min_len,
max_size=max_len,
min_size=min_size,
max_size=max_size,
unique=unique,
**kwargs,
)

return st.lists(
elements=inner_strategy,
min_size=min_len,
max_size=max_len,
min_size=min_size,
max_size=max_size,
unique_by=(flexhash if unique else None),
)

Expand Down Expand Up @@ -377,10 +377,11 @@ def data(
elif dtype == Array:
inner = getattr(dtype, "inner", None) or Null()
width = getattr(dtype, "width", _DEFAULT_ARRAY_WIDTH_LIMIT)
kwargs = {k: v for k, v in kwargs.items() if k not in ("min_size", "max_size")}
strategy = lists(
inner,
min_len=width,
max_len=width,
min_size=width,
max_size=width,
allow_null=allow_null,
**kwargs,
)
Expand Down
81 changes: 51 additions & 30 deletions py-polars/polars/testing/parametric/strategies/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Int8,
Expand Down Expand Up @@ -222,8 +223,7 @@ def _instantiate_nested_dtype(
) -> DataType:
"""Take a nested data type and instantiate it."""

def instantiate_inner(dtype: PolarsDataType) -> DataType:
inner_dtype = getattr(dtype, "inner", None)
def instantiate_inner(inner_dtype: PolarsDataType | None) -> DataType:
if inner_dtype is None:
return draw(inner)
elif inner_dtype.is_nested():
Expand All @@ -232,24 +232,25 @@ def instantiate_inner(dtype: PolarsDataType) -> DataType:
return draw(_instantiate_flat_dtype(inner_dtype))

if dtype == List:
inner_dtype = instantiate_inner(dtype)
inner_dtype = instantiate_inner(getattr(dtype, "inner", None))
return List(inner_dtype)
elif dtype == Array:
inner_dtype = instantiate_inner(dtype)
inner_dtype = instantiate_inner(getattr(dtype, "inner", None))
width = getattr(
dtype,
"width",
draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT)),
)
return Array(inner_dtype, width)
elif dtype == Struct:
# TODO: Recursively instantiate struct field dtypes
if isinstance(dtype, DataType):
return dtype
n_fields = draw(
st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT)
)
return Struct({f"f{i}": draw(inner) for i in range(n_fields)})
if isinstance(dtype, Struct):
fields = [Field(f.name, instantiate_inner(f.dtype)) for f in dtype.fields]
else:
n_fields = draw(
st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT)
)
fields = [Field(f"f{i}", draw(inner)) for i in range(n_fields)]
return Struct(fields)
else:
msg = f"unsupported data type: {dtype}"
raise InvalidArgument(msg)
Expand All @@ -276,56 +277,76 @@ def _instantiate_dtype(
) -> DataType:
"""Take a data type and instantiate it."""
if not dtype.is_nested():
if isinstance(dtype, DataType):
return dtype

if allowed_dtypes is None:
allowed_dtypes = [dtype]
else:
allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype]
same_dtypes = [dt for dt in allowed_dtypes if dt == dtype]
allowed_dtypes = same_dtypes if same_dtypes else [dtype]

return draw(
_flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes)
)

def draw_inner(dtype: PolarsDataType) -> DataType:
if isinstance(dtype, DataType):
def draw_inner(dtype: PolarsDataType | None) -> DataType:
if dtype is None:
return draw(
_instantiate_dtype(
dtype.inner, # type: ignore[attr-defined]
dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
)
else:
return draw(
dtypes(
_instantiate_dtype(
dtype,
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
)

if dtype == List:
inner = draw_inner(dtype)
inner = draw_inner(getattr(dtype, "inner", None))
return List(inner)
elif dtype == Array:
inner = draw_inner(dtype)
inner = draw_inner(getattr(dtype, "inner", None))
width = getattr(
dtype,
"width",
draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT)),
)
return Array(inner, width)
elif dtype == Struct:
if isinstance(dtype, DataType):
return dtype
n_fields = draw(
st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT)
)
inner_strategy = dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
return Struct({f"f{i}": draw(inner_strategy) for i in range(n_fields)})
if isinstance(dtype, Struct):
fields = [
Field(
name=f.name,
dtype=draw(
_instantiate_dtype(
f.dtype,
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
),
)
for f in dtype.fields
]
else:
n_fields = draw(
st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT)
)
inner_strategy = dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
fields = [Field(f"f{i}", draw(inner_strategy)) for i in range(n_fields)]
return Struct(fields)
else:
msg = f"unsupported data type: {dtype}"
raise InvalidArgument(msg)
4 changes: 2 additions & 2 deletions py-polars/polars/testing/parametric/strategies/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def create_list_strategy(
return lists(
inner_dtype,
select_from=select_from,
min_len=min_size,
max_len=max_size,
min_size=min_size,
max_size=max_size,
unique=unique,
)
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_allow_infinities_deprecated(data: st.DataObject) -> None:
strategy=lists(
inner_dtype=pl.List(pl.String),
select_from=["aa", "bb", "cc"],
min_len=1,
min_size=1,
),
),
],
Expand Down

0 comments on commit f305c8c

Please sign in to comment.