From f305c8c2630961f098fc78a9f64ffde7ab7bb0e7 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 14 May 2024 16:30:56 +0200 Subject: [PATCH] fix(python): Fix some issues in parametric testing with nested dtypes (#16211) --- .../testing/parametric/strategies/data.py | 29 +++---- .../testing/parametric/strategies/dtype.py | 81 ++++++++++++------- .../testing/parametric/strategies/legacy.py | 4 +- .../parametric/strategies/test_core.py | 2 +- 4 files changed, 69 insertions(+), 47 deletions(-) diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 71313436dc36..bf159d7d2601 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -226,8 +226,8 @@ def lists( inner_dtype: DataType, *, select_from: Sequence[Any] | None = None, - min_len: int = 0, - max_len: int | None = None, + min_size: int = 0, + max_size: int | None = None, unique: bool = False, **kwargs: Any, ) -> SearchStrategy[list[Any]]: @@ -242,12 +242,12 @@ def lists( select_from The values to use for the innermost lists. If set to `None` (default), the default strategy associated with the innermost data type is used. - min_len + min_size The minimum length of the generated lists. - max_len + max_size The maximum length of the generated lists. If set to `None` (default), the - maximum is set based on `min_size`: `3` if `min_len` is zero, - otherwise `2 * min_len`. + maximum is set based on `min_size`: `3` if `min_size` is zero, + otherwise `2 * min_size`. unique Ensure that the generated lists contain unique values. **kwargs @@ -257,8 +257,8 @@ def lists( -------- ... """ - if max_len is None: - max_len = _DEFAULT_LIST_LEN_LIMIT if min_len == 0 else min_len * 2 + if max_size is None: + max_size = _DEFAULT_LIST_LEN_LIMIT if min_size == 0 else min_size * 2 if select_from is not None and not inner_dtype.is_nested(): inner_strategy = st.sampled_from(select_from) @@ -266,16 +266,16 @@ def lists( inner_strategy = data( inner_dtype, select_from=select_from, - min_size=min_len, - max_size=max_len, + min_size=min_size, + max_size=max_size, unique=unique, **kwargs, ) return st.lists( elements=inner_strategy, - min_size=min_len, - max_size=max_len, + min_size=min_size, + max_size=max_size, unique_by=(flexhash if unique else None), ) @@ -377,10 +377,11 @@ def data( elif dtype == Array: inner = getattr(dtype, "inner", None) or Null() width = getattr(dtype, "width", _DEFAULT_ARRAY_WIDTH_LIMIT) + kwargs = {k: v for k, v in kwargs.items() if k not in ("min_size", "max_size")} strategy = lists( inner, - min_len=width, - max_len=width, + min_size=width, + max_size=width, allow_null=allow_null, **kwargs, ) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index dac7049def65..835bb4d8103d 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -16,6 +16,7 @@ Decimal, Duration, Enum, + Field, Float32, Float64, Int8, @@ -222,8 +223,7 @@ def _instantiate_nested_dtype( ) -> DataType: """Take a nested data type and instantiate it.""" - def instantiate_inner(dtype: PolarsDataType) -> DataType: - inner_dtype = getattr(dtype, "inner", None) + def instantiate_inner(inner_dtype: PolarsDataType | None) -> DataType: if inner_dtype is None: return draw(inner) elif inner_dtype.is_nested(): @@ -232,10 +232,10 @@ def instantiate_inner(dtype: PolarsDataType) -> DataType: return draw(_instantiate_flat_dtype(inner_dtype)) if dtype == List: - inner_dtype = instantiate_inner(dtype) + inner_dtype = instantiate_inner(getattr(dtype, "inner", None)) return List(inner_dtype) elif dtype == Array: - inner_dtype = instantiate_inner(dtype) + inner_dtype = instantiate_inner(getattr(dtype, "inner", None)) width = getattr( dtype, "width", @@ -243,13 +243,14 @@ def instantiate_inner(dtype: PolarsDataType) -> DataType: ) return Array(inner_dtype, width) elif dtype == Struct: - # TODO: Recursively instantiate struct field dtypes - if isinstance(dtype, DataType): - return dtype - n_fields = draw( - st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) - ) - return Struct({f"f{i}": draw(inner) for i in range(n_fields)}) + if isinstance(dtype, Struct): + fields = [Field(f.name, instantiate_inner(f.dtype)) for f in dtype.fields] + else: + n_fields = draw( + st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) + ) + fields = [Field(f"f{i}", draw(inner)) for i in range(n_fields)] + return Struct(fields) else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) @@ -276,19 +277,23 @@ def _instantiate_dtype( ) -> DataType: """Take a data type and instantiate it.""" if not dtype.is_nested(): + if isinstance(dtype, DataType): + return dtype + if allowed_dtypes is None: allowed_dtypes = [dtype] else: - allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype] + same_dtypes = [dt for dt in allowed_dtypes if dt == dtype] + allowed_dtypes = same_dtypes if same_dtypes else [dtype] + return draw( _flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) ) - def draw_inner(dtype: PolarsDataType) -> DataType: - if isinstance(dtype, DataType): + def draw_inner(dtype: PolarsDataType | None) -> DataType: + if dtype is None: return draw( - _instantiate_dtype( - dtype.inner, # type: ignore[attr-defined] + dtypes( allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes, nesting_level=nesting_level - 1, @@ -296,7 +301,8 @@ def draw_inner(dtype: PolarsDataType) -> DataType: ) else: return draw( - dtypes( + _instantiate_dtype( + dtype, allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes, nesting_level=nesting_level - 1, @@ -304,10 +310,10 @@ def draw_inner(dtype: PolarsDataType) -> DataType: ) if dtype == List: - inner = draw_inner(dtype) + inner = draw_inner(getattr(dtype, "inner", None)) return List(inner) elif dtype == Array: - inner = draw_inner(dtype) + inner = draw_inner(getattr(dtype, "inner", None)) width = getattr( dtype, "width", @@ -315,17 +321,32 @@ def draw_inner(dtype: PolarsDataType) -> DataType: ) return Array(inner, width) elif dtype == Struct: - if isinstance(dtype, DataType): - return dtype - n_fields = draw( - st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) - ) - inner_strategy = dtypes( - allowed_dtypes=allowed_dtypes, - excluded_dtypes=excluded_dtypes, - nesting_level=nesting_level - 1, - ) - return Struct({f"f{i}": draw(inner_strategy) for i in range(n_fields)}) + if isinstance(dtype, Struct): + fields = [ + Field( + name=f.name, + dtype=draw( + _instantiate_dtype( + f.dtype, + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + ), + ) + for f in dtype.fields + ] + else: + n_fields = draw( + st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) + ) + inner_strategy = dtypes( + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + fields = [Field(f"f{i}", draw(inner_strategy)) for i in range(n_fields)] + return Struct(fields) else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/legacy.py b/py-polars/polars/testing/parametric/strategies/legacy.py index 434d791c604b..46f0cc1188e9 100644 --- a/py-polars/polars/testing/parametric/strategies/legacy.py +++ b/py-polars/polars/testing/parametric/strategies/legacy.py @@ -150,7 +150,7 @@ def create_list_strategy( return lists( inner_dtype, select_from=select_from, - min_len=min_size, - max_len=max_size, + min_size=min_size, + max_size=max_size, unique=unique, ) diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py index cec48b8acbd9..08aac7fb75a1 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py @@ -199,7 +199,7 @@ def test_allow_infinities_deprecated(data: st.DataObject) -> None: strategy=lists( inner_dtype=pl.List(pl.String), select_from=["aa", "bb", "cc"], - min_len=1, + min_size=1, ), ), ],