From 8a6bf4bc58e7fed9b6728bad66e0590fccb11f0e Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 20 Jun 2024 13:51:24 +0200 Subject: [PATCH] refactor(python): Remove re-export of data type groups (#17073) --- py-polars/polars/__init__.py | 34 +++--- py-polars/polars/_utils/various.py | 3 +- py-polars/polars/dataframe/frame.py | 28 ++--- py-polars/polars/datatypes/__init__.py | 22 +--- py-polars/polars/datatypes/classes.py | 32 ----- py-polars/polars/datatypes/constants.py | 81 +----------- py-polars/polars/datatypes/group.py | 115 ++++++++++++++++++ py-polars/polars/functions/repeat.py | 3 +- py-polars/polars/io/database/_inference.py | 6 +- .../polars/io/spreadsheet/_write_utils.py | 3 +- py-polars/polars/io/spreadsheet/functions.py | 4 +- py-polars/polars/lazyframe/frame.py | 22 ++-- py-polars/polars/selectors.py | 62 +++++----- py-polars/polars/testing/asserts/series.py | 2 +- py-polars/tests/unit/conftest.py | 13 ++ py-polars/tests/unit/dataframe/test_df.py | 3 +- py-polars/tests/unit/datatypes/test_list.py | 27 ++-- .../tests/unit/datatypes/test_temporal.py | 17 +-- py-polars/tests/unit/expr/test_exprs.py | 4 +- py-polars/tests/unit/io/test_spreadsheet.py | 7 +- .../tests/unit/lazyframe/test_lazyframe.py | 16 +-- py-polars/tests/unit/ml/test_to_torch.py | 2 +- .../operations/arithmetic/test_arithmetic.py | 52 ++++---- .../unit/operations/namespaces/test_meta.py | 3 +- .../unit/series/buffers/test_from_buffer.py | 3 +- .../unit/series/buffers/test_from_buffers.py | 26 ++-- .../series/buffers/test_get_buffer_info.py | 9 +- .../unit/streaming/test_streaming_group_by.py | 3 +- py-polars/tests/unit/test_datatypes.py | 42 +++---- py-polars/tests/unit/test_errors.py | 19 +-- py-polars/tests/unit/test_expansion.py | 2 +- py-polars/tests/unit/test_init.py | 7 ++ py-polars/tests/unit/test_selectors.py | 15 +-- 33 files changed, 348 insertions(+), 339 deletions(-) create mode 100644 py-polars/polars/datatypes/group.py diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 215780fada83..3a00192e7827 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import contextlib import os @@ -40,13 +38,6 @@ ) from polars.dataframe import DataFrame from polars.datatypes import ( - DATETIME_DTYPES, - DURATION_DTYPES, - FLOAT_DTYPES, - INTEGER_DTYPES, - NESTED_DTYPES, - NUMERIC_DTYPES, - TEMPORAL_DTYPES, Array, Binary, Boolean, @@ -250,14 +241,6 @@ "UInt64", "Unknown", "Utf8", - # polars.datatypes: dtype groups - "DATETIME_DTYPES", - "DURATION_DTYPES", - "FLOAT_DTYPES", - "INTEGER_DTYPES", - "NESTED_DTYPES", - "NUMERIC_DTYPES", - "TEMPORAL_DTYPES", # polars.io "read_avro", "read_clipboard", @@ -401,7 +384,7 @@ os.environ["POLARS_ALLOW_EXTENSION"] = "true" -def __getattr__(name: str) -> type[Exception]: +def __getattr__(name: str): # type: ignore[no-untyped-def] # Deprecate re-export of exceptions at top-level if name in dir(exceptions): from polars._utils.deprecation import issue_deprecation_warning @@ -416,5 +399,20 @@ def __getattr__(name: str) -> type[Exception]: ) return getattr(exceptions, name) + # Deprecate data type groups at top-level + import polars.datatypes.group as dtgroup + + if name in dir(dtgroup): + from polars._utils.deprecation import issue_deprecation_warning + + issue_deprecation_warning( + message=( + f"`{name}` is deprecated. Define your own data type groups or use the" + " `polars.selectors` module for selecting columns of a certain data type." + ), + version="1.0.0", + ) + return getattr(dtgroup, name) + msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index e1e1c0b69fb4..09701fd677cc 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -23,8 +23,6 @@ import polars as pl from polars import functions as F from polars.datatypes import ( - FLOAT_DTYPES, - INTEGER_DTYPES, Boolean, Date, Datetime, @@ -34,6 +32,7 @@ String, Time, ) +from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index ba3829f36192..ec0ce54249ba 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -61,7 +61,6 @@ from polars.dataframe._html import NotebookFormatter from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy from polars.datatypes import ( - INTEGER_DTYPES, N_INFER_DEFAULT, Boolean, Float32, @@ -75,6 +74,7 @@ UInt32, UInt64, ) +from polars.datatypes.group import INTEGER_DTYPES from polars.dependencies import ( _GREAT_TABLES_AVAILABLE, _HVPLOT_AVAILABLE, @@ -2784,9 +2784,7 @@ def write_excel( dtype_formats : dict A `{dtype:str,}` dictionary that sets the default Excel format for the given dtype. (This can be overridden on a per-column basis by the - `column_formats` param). It is also valid to use dtype groups such as - `pl.FLOAT_DTYPES` as the dtype/format key, to simplify setting uniform - integer and float formats. + `column_formats` param). conditional_formats : dict A dictionary of colname (or selector) keys to a format str, dict, or list that defines conditional formatting options for the specified columns. @@ -3022,7 +3020,7 @@ def write_excel( >>> df.write_excel( # doctest: +SKIP ... table_style="Table Style Light 2", ... # apply accounting format to all flavours of integer - ... dtype_formats={pl.INTEGER_DTYPES: "#,##0_);(#,##0)"}, + ... dtype_formats={dt: "#,##0_);(#,##0)" for dt in [pl.Int32, pl.Int64]}, ... sparklines={ ... # default options; just provide source cols ... "trend": ["q1", "q2", "q3", "q4"], @@ -8459,18 +8457,18 @@ def select( >>> with pl.Config(auto_structify=True): ... df.select( - ... is_odd=(pl.col(pl.INTEGER_DTYPES) % 2).name.suffix("_is_odd"), + ... is_odd=(pl.col(pl.Int64) % 2 == 1).name.suffix("_is_odd"), ... ) shape: (3, 1) - ┌───────────┐ - │ is_odd │ - │ --- │ - │ struct[2] │ - ╞═══════════╡ - │ {1,0} │ - │ {0,1} │ - │ {1,0} │ - └───────────┘ + ┌──────────────┐ + │ is_odd │ + │ --- │ + │ struct[2] │ + ╞══════════════╡ + │ {true,false} │ + │ {false,true} │ + │ {true,false} │ + └──────────────┘ """ return self.lazy().select(*exprs, **named_exprs).collect(_eager=True) diff --git a/py-polars/polars/datatypes/__init__.py b/py-polars/polars/datatypes/__init__.py index aa9ebb507ded..fc36c47de77b 100644 --- a/py-polars/polars/datatypes/__init__.py +++ b/py-polars/polars/datatypes/__init__.py @@ -5,7 +5,6 @@ Categorical, DataType, DataTypeClass, - DataTypeGroup, Date, Datetime, Decimal, @@ -34,17 +33,8 @@ Utf8, ) from polars.datatypes.constants import ( - DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, - DURATION_DTYPES, - FLOAT_DTYPES, - INTEGER_DTYPES, N_INFER_DEFAULT, - NESTED_DTYPES, - NUMERIC_DTYPES, - SIGNED_INTEGER_DTYPES, - TEMPORAL_DTYPES, - UNSIGNED_INTEGER_DTYPES, ) from polars.datatypes.constructor import ( numpy_type_to_constructor, @@ -72,7 +62,6 @@ "Categorical", "DataType", "DataTypeClass", - "DataTypeGroup", "Date", "Datetime", "Decimal", @@ -100,17 +89,8 @@ "Unknown", "Utf8", # constants - "DATETIME_DTYPES", - "DTYPE_TEMPORAL_UNITS", - "DURATION_DTYPES", - "FLOAT_DTYPES", - "INTEGER_DTYPES", - "NESTED_DTYPES", - "NUMERIC_DTYPES", "N_INFER_DEFAULT", - "SIGNED_INTEGER_DTYPES", - "TEMPORAL_DTYPES", - "UNSIGNED_INTEGER_DTYPES", + "DTYPE_TEMPORAL_UNITS", # constructor "numpy_type_to_constructor", "numpy_values_and_dtype", diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index b186d7348cf9..cf2be33f2cbc 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -181,38 +181,6 @@ def is_nested(cls) -> bool: return issubclass(cls, NestedType) -class DataTypeGroup(frozenset): # type: ignore[type-arg] - """Group of data types.""" - - _match_base_type: bool - - def __new__( - cls, items: Iterable[DataType | DataTypeClass], *, match_base_type: bool = True - ) -> DataTypeGroup: - """ - Construct a DataTypeGroup. - - Parameters - ---------- - items : - iterable of data types - match_base_type: - match the base type - """ - for it in items: - if not isinstance(it, (DataType, DataTypeClass)): - msg = f"DataTypeGroup items must be dtypes; found {type(it).__name__!r}" - raise TypeError(msg) - dtype_group = super().__new__(cls, items) - dtype_group._match_base_type = match_base_type - return dtype_group - - def __contains__(self, item: Any) -> bool: - if self._match_base_type and isinstance(item, (DataType, DataTypeClass)): - item = item.base_type() - return super().__contains__(item) - - class NumericType(DataType): """Base class for numeric data types.""" diff --git a/py-polars/polars/datatypes/constants.py b/py-polars/polars/datatypes/constants.py index 2a3716697560..b0eed7097f61 100644 --- a/py-polars/polars/datatypes/constants.py +++ b/py-polars/polars/datatypes/constants.py @@ -2,85 +2,10 @@ from typing import TYPE_CHECKING -from polars.datatypes import ( - Array, - DataTypeGroup, - Date, - Datetime, - Decimal, - Duration, - Float32, - Float64, - Int8, - Int16, - Int32, - Int64, - List, - Struct, - Time, - UInt8, - UInt16, - UInt32, - UInt64, -) - if TYPE_CHECKING: - from polars.type_aliases import ( - PolarsDataType, - PolarsIntegerType, - PolarsTemporalType, - TimeUnit, - ) + from polars.type_aliases import TimeUnit +# Number of rows to scan by default when inferring datatypes +N_INFER_DEFAULT = 100 DTYPE_TEMPORAL_UNITS: frozenset[TimeUnit] = frozenset(["ns", "us", "ms"]) -DATETIME_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( - [ - Datetime, - Datetime("ms"), - Datetime("us"), - Datetime("ns"), - Datetime("ms", "*"), - Datetime("us", "*"), - Datetime("ns", "*"), - ] -) -DURATION_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( - [ - Duration, - Duration("ms"), - Duration("us"), - Duration("ns"), - ] -) -TEMPORAL_DTYPES: frozenset[PolarsTemporalType] = DataTypeGroup( - frozenset([Date, Time]) | DATETIME_DTYPES | DURATION_DTYPES -) -SIGNED_INTEGER_DTYPES: frozenset[PolarsIntegerType] = DataTypeGroup( - [ - Int8, - Int16, - Int32, - Int64, - ] -) -UNSIGNED_INTEGER_DTYPES: frozenset[PolarsIntegerType] = DataTypeGroup( - [ - UInt8, - UInt16, - UInt32, - UInt64, - ] -) -INTEGER_DTYPES: frozenset[PolarsIntegerType] = ( - SIGNED_INTEGER_DTYPES | UNSIGNED_INTEGER_DTYPES -) -FLOAT_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([Float32, Float64]) -NUMERIC_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( - FLOAT_DTYPES | INTEGER_DTYPES | frozenset([Decimal]) -) - -NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct, Array]) - -# number of rows to scan by default when inferring datatypes -N_INFER_DEFAULT = 100 diff --git a/py-polars/polars/datatypes/group.py b/py-polars/polars/datatypes/group.py new file mode 100644 index 000000000000..aafe7ad52b81 --- /dev/null +++ b/py-polars/polars/datatypes/group.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable + +from polars.datatypes.classes import ( + Array, + DataType, + DataTypeClass, + Date, + Datetime, + Decimal, + Duration, + Float32, + Float64, + Int8, + Int16, + Int32, + Int64, + List, + Struct, + Time, + UInt8, + UInt16, + UInt32, + UInt64, +) + +if TYPE_CHECKING: + from polars.type_aliases import ( + PolarsDataType, + PolarsIntegerType, + PolarsTemporalType, + ) + + +class DataTypeGroup(frozenset): # type: ignore[type-arg] + """Group of data types.""" + + _match_base_type: bool + + def __new__( + cls, items: Iterable[DataType | DataTypeClass], *, match_base_type: bool = True + ) -> DataTypeGroup: + """ + Construct a DataTypeGroup. + + Parameters + ---------- + items : + iterable of data types + match_base_type: + match the base type + """ + for it in items: + if not isinstance(it, (DataType, DataTypeClass)): + msg = f"DataTypeGroup items must be dtypes; found {type(it).__name__!r}" + raise TypeError(msg) + dtype_group = super().__new__(cls, items) + dtype_group._match_base_type = match_base_type + return dtype_group + + def __contains__(self, item: Any) -> bool: + if self._match_base_type and isinstance(item, (DataType, DataTypeClass)): + item = item.base_type() + return super().__contains__(item) + + +SIGNED_INTEGER_DTYPES: frozenset[PolarsIntegerType] = DataTypeGroup( + [ + Int8, + Int16, + Int32, + Int64, + ] +) +UNSIGNED_INTEGER_DTYPES: frozenset[PolarsIntegerType] = DataTypeGroup( + [ + UInt8, + UInt16, + UInt32, + UInt64, + ] +) +INTEGER_DTYPES: frozenset[PolarsIntegerType] = ( + SIGNED_INTEGER_DTYPES | UNSIGNED_INTEGER_DTYPES +) +FLOAT_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([Float32, Float64]) +NUMERIC_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( + FLOAT_DTYPES | INTEGER_DTYPES | frozenset([Decimal]) +) + +DATETIME_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( + [ + Datetime, + Datetime("ms"), + Datetime("us"), + Datetime("ns"), + Datetime("ms", "*"), + Datetime("us", "*"), + Datetime("ns", "*"), + ] +) +DURATION_DTYPES: frozenset[PolarsDataType] = DataTypeGroup( + [ + Duration, + Duration("ms"), + Duration("us"), + Duration("ns"), + ] +) +TEMPORAL_DTYPES: frozenset[PolarsTemporalType] = DataTypeGroup( + frozenset([Date, Time]) | DATETIME_DTYPES | DURATION_DTYPES +) + +NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct, Array]) diff --git a/py-polars/polars/functions/repeat.py b/py-polars/polars/functions/repeat.py index b92077972b39..c23aa8c2fac6 100644 --- a/py-polars/polars/functions/repeat.py +++ b/py-polars/polars/functions/repeat.py @@ -9,8 +9,6 @@ from polars._utils.parse import parse_into_expression from polars._utils.wrap import wrap_expr from polars.datatypes import ( - FLOAT_DTYPES, - INTEGER_DTYPES, Array, Boolean, Decimal, @@ -18,6 +16,7 @@ List, Utf8, ) +from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/io/database/_inference.py b/py-polars/polars/io/database/_inference.py index e69e8c90796a..2a34eeba8314 100644 --- a/py-polars/polars/io/database/_inference.py +++ b/py-polars/polars/io/database/_inference.py @@ -7,8 +7,6 @@ from typing import TYPE_CHECKING, Any from polars.datatypes import ( - INTEGER_DTYPES, - UNSIGNED_INTEGER_DTYPES, Binary, Boolean, Date, @@ -31,6 +29,10 @@ UInt64, ) from polars.datatypes.convert import _map_py_type_to_dtype +from polars.datatypes.group import ( + INTEGER_DTYPES, + UNSIGNED_INTEGER_DTYPES, +) if TYPE_CHECKING: from polars.type_aliases import PolarsDataType diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index a12839b15d25..3dc8e15a206f 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -6,8 +6,6 @@ from polars import functions as F from polars.datatypes import ( - FLOAT_DTYPES, - INTEGER_DTYPES, Date, Datetime, Float64, @@ -16,6 +14,7 @@ Struct, Time, ) +from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES from polars.dependencies import json from polars.exceptions import DuplicateError from polars.selectors import _expand_selector_dicts, _expand_selectors diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 08468cee7183..26af04440cd1 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -14,10 +14,7 @@ ) from polars._utils.various import normalize_filepath, parse_version from polars.datatypes import ( - FLOAT_DTYPES, - INTEGER_DTYPES, N_INFER_DEFAULT, - NUMERIC_DTYPES, Boolean, Date, Datetime, @@ -26,6 +23,7 @@ Null, String, ) +from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES, NUMERIC_DTYPES from polars.dependencies import import_optional from polars.exceptions import ( ModuleUpgradeRequiredError, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index b520be9b1a80..698884360910 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -52,7 +52,6 @@ N_INFER_DEFAULT, Boolean, Categorical, - DataTypeGroup, Date, Datetime, Duration, @@ -75,6 +74,7 @@ is_polars_dtype, py_type_to_dtype, ) +from polars.datatypes.group import DataTypeGroup from polars.dependencies import import_optional, subprocess from polars.exceptions import PerformanceWarning from polars.lazyframe.group_by import LazyGroupBy @@ -3099,18 +3099,18 @@ def select( >>> with pl.Config(auto_structify=True): ... lf.select( - ... is_odd=(pl.col(pl.INTEGER_DTYPES) % 2).name.suffix("_is_odd"), + ... is_odd=(pl.col(pl.Int64) % 2 == 1).name.suffix("_is_odd"), ... ).collect() shape: (3, 1) - ┌───────────┐ - │ is_odd │ - │ --- │ - │ struct[2] │ - ╞═══════════╡ - │ {1,0} │ - │ {0,1} │ - │ {1,0} │ - └───────────┘ + ┌──────────────┐ + │ is_odd │ + │ --- │ + │ struct[2] │ + ╞══════════════╡ + │ {true,false} │ + │ {false,true} │ + │ {true,false} │ + └──────────────┘ """ structify = bool(int(os.environ.get("POLARS_AUTO_STRUCTIFY", 0))) diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 9275692de43d..bf22336a273c 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -18,12 +18,6 @@ from polars._utils.parse.expr import _parse_inputs_as_iterable from polars._utils.various import is_column, re_escape from polars.datatypes import ( - FLOAT_DTYPES, - INTEGER_DTYPES, - NUMERIC_DTYPES, - SIGNED_INTEGER_DTYPES, - TEMPORAL_DTYPES, - UNSIGNED_INTEGER_DTYPES, Binary, Boolean, Categorical, @@ -36,6 +30,14 @@ Time, is_polars_dtype, ) +from polars.datatypes.group import ( + FLOAT_DTYPES, + INTEGER_DTYPES, + NUMERIC_DTYPES, + SIGNED_INTEGER_DTYPES, + TEMPORAL_DTYPES, + UNSIGNED_INTEGER_DTYPES, +) from polars.expr import Expr if TYPE_CHECKING: @@ -895,33 +897,33 @@ def by_dtype( ... } ... ) - Select all columns with date or integer dtypes: + Select all columns with date or string dtypes: - >>> df.select(cs.by_dtype(pl.Date, pl.INTEGER_DTYPES)) + >>> df.select(cs.by_dtype(pl.Date, pl.String)) shape: (3, 2) - ┌────────────┬──────────┐ - │ dt ┆ value │ - │ --- ┆ --- │ - │ date ┆ i64 │ - ╞════════════╪══════════╡ - │ 1999-12-31 ┆ 1234500 │ - │ 2024-01-01 ┆ 5000555 │ - │ 2010-07-05 ┆ -4500000 │ - └────────────┴──────────┘ - - Select all columns that are not of date or integer dtype: - - >>> df.select(~cs.by_dtype(pl.Date, pl.INTEGER_DTYPES)) + ┌────────────┬───────┐ + │ dt ┆ other │ + │ --- ┆ --- │ + │ date ┆ str │ + ╞════════════╪═══════╡ + │ 1999-12-31 ┆ foo │ + │ 2024-01-01 ┆ bar │ + │ 2010-07-05 ┆ foo │ + └────────────┴───────┘ + + Select all columns that are not of date or string dtype: + + >>> df.select(~cs.by_dtype(pl.Date, pl.String)) shape: (3, 1) - ┌───────┐ - │ other │ - │ --- │ - │ str │ - ╞═══════╡ - │ foo │ - │ bar │ - │ foo │ - └───────┘ + ┌──────────┐ + │ value │ + │ --- │ + │ i64 │ + ╞══════════╡ + │ 1234500 │ + │ 5000555 │ + │ -4500000 │ + └──────────┘ Group by string columns and sum the numeric columns: diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index dd64b88fbfb1..ad316f565aad 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -4,7 +4,6 @@ from polars._utils.deprecation import deprecate_renamed_parameter from polars.datatypes import ( - FLOAT_DTYPES, Array, Categorical, List, @@ -12,6 +11,7 @@ Struct, unpack_dtypes, ) +from polars.datatypes.group import FLOAT_DTYPES from polars.exceptions import ComputeError, InvalidOperationError from polars.series import Series from polars.testing.asserts.utils import raise_assertion_error diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py index c073699ce948..f5f848ef4573 100644 --- a/py-polars/tests/unit/conftest.py +++ b/py-polars/tests/unit/conftest.py @@ -18,6 +18,19 @@ profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] ) +# Data type groups +SIGNED_INTEGER_DTYPES = [pl.Int8(), pl.Int16(), pl.Int32(), pl.Int64()] +UNSIGNED_INTEGER_DTYPES = [pl.UInt8(), pl.UInt16(), pl.UInt32(), pl.UInt64()] +INTEGER_DTYPES = SIGNED_INTEGER_DTYPES + UNSIGNED_INTEGER_DTYPES +FLOAT_DTYPES = [pl.Float32(), pl.Float64()] +NUMERIC_DTYPES = INTEGER_DTYPES + FLOAT_DTYPES + +DATETIME_DTYPES = [pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")] +DURATION_DTYPES = [pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")] +TEMPORAL_DTYPES = [*DATETIME_DTYPES, *DURATION_DTYPES, pl.Date(), pl.Time()] + +NESTED_DTYPES = [pl.List, pl.Struct, pl.Array] + @pytest.fixture() def partition_limit() -> int: diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index b41bba5f6adb..e55fc817f05e 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -16,7 +16,7 @@ import polars as pl import polars.selectors as cs from polars._utils.construction import iterable_to_pydf -from polars.datatypes import DTYPE_TEMPORAL_UNITS, INTEGER_DTYPES +from polars.datatypes import DTYPE_TEMPORAL_UNITS from polars.exceptions import ( ComputeError, DuplicateError, @@ -28,6 +28,7 @@ assert_frame_not_equal, assert_series_equal, ) +from tests.unit.conftest import INTEGER_DTYPES if TYPE_CHECKING: from zoneinfo import ZoneInfo diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index be017681f0dd..091a949f63bc 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -10,6 +10,7 @@ import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import NUMERIC_DTYPES if TYPE_CHECKING: from polars.type_aliases import PolarsDataType @@ -401,20 +402,22 @@ def test_list_any() -> None: } -def test_list_min_max() -> None: - for dt in pl.INTEGER_DTYPES | pl.FLOAT_DTYPES: - df = pl.DataFrame( - {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, - schema={"a": pl.List(dt)}, - ) - result = df.select(pl.col("a").list.min()) - expected = df.select(pl.col("a").list.first()) - assert_frame_equal(result, expected) +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_list_min_max(dtype: pl.DataType) -> None: + df = pl.DataFrame( + {"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}, + schema={"a": pl.List(dtype)}, + ) + result = df.select(pl.col("a").list.min()) + expected = df.select(pl.col("a").list.first()) + assert_frame_equal(result, expected) + + result = df.select(pl.col("a").list.max()) + expected = df.select(pl.col("a").list.last()) + assert_frame_equal(result, expected) - result = df.select(pl.col("a").list.max()) - expected = df.select(pl.col("a").list.last()) - assert_frame_equal(result, expected) +def test_list_min_max2() -> None: df = pl.DataFrame( {"a": [[1], [1, 5, -1, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}, ) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index b2759769273d..da2f19a2ce6b 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -10,7 +10,7 @@ import pytest import polars as pl -from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, TEMPORAL_DTYPES +from polars.datatypes import DTYPE_TEMPORAL_UNITS from polars.exceptions import ( ComputeError, InvalidOperationError, @@ -21,6 +21,7 @@ assert_series_equal, assert_series_not_equal, ) +from tests.unit.conftest import DATETIME_DTYPES, TEMPORAL_DTYPES if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -2310,13 +2311,13 @@ def test_year_null_backed_by_out_of_range_15313() -> None: assert_series_equal(result, expected) -def test_series_is_temporal() -> None: - for tp in TEMPORAL_DTYPES | { - pl.Datetime("ms", "UTC"), - pl.Datetime("ns", "Europe/Amsterdam"), - }: - s = pl.Series([None], dtype=tp) - assert s.dtype.is_temporal() is True +@pytest.mark.parametrize( + "dtype", + [*TEMPORAL_DTYPES, pl.Datetime("ms", "UTC"), pl.Datetime("ns", "Europe/Amsterdam")], +) +def test_series_is_temporal(dtype: pl.DataType) -> None: + s = pl.Series([None], dtype=dtype) + assert s.dtype.is_temporal() is True @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index d28e75c5793c..2561dcecb466 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -8,7 +8,8 @@ import pytest import polars as pl -from polars.datatypes import ( +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import ( DATETIME_DTYPES, DURATION_DTYPES, FLOAT_DTYPES, @@ -16,7 +17,6 @@ NUMERIC_DTYPES, TEMPORAL_DTYPES, ) -from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from zoneinfo import ZoneInfo diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 21c25569fe84..180d23c33216 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -14,6 +14,7 @@ from polars.exceptions import NoDataError, ParameterCollisionError from polars.io.spreadsheet.functions import _identify_workbook from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, NUMERIC_DTYPES if TYPE_CHECKING: from polars.type_aliases import ExcelSpreadsheetEngine, SchemaDict, SelectorType @@ -597,7 +598,9 @@ def test_read_excel_all_sheets_with_sheet_name(path_xlsx: Path, engine: str) -> ], }, "dtype_formats": { - pl.FLOAT_DTYPES: '_(£* #,##0.00_);_(£* (#,##0.00);_(£* "-"??_);_(@_)', + frozenset( + FLOAT_DTYPES + ): '_(£* #,##0.00_);_(£* (#,##0.00);_(£* "-"??_);_(@_)', pl.Date: "dd-mm-yyyy", }, "column_formats": {"dtm": {"font_color": "#31869c", "bg_color": "#b7dee8"}}, @@ -692,7 +695,7 @@ def test_excel_sparklines(engine: ExcelSpreadsheetEngine) -> None: workbook=wb, worksheet="frame_data", table_style="Table Style Light 2", - dtype_formats={pl.INTEGER_DTYPES: "#,##0_);(#,##0)"}, + dtype_formats={frozenset(NUMERIC_DTYPES): "#,##0_);(#,##0)"}, column_formats={cs.starts_with("h"): "#,##0_);(#,##0)"}, sparklines={ "trend": ["q1", "q2", "q3", "q4"], diff --git a/py-polars/tests/unit/lazyframe/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_lazyframe.py index 68628c17c715..f4fb1a11403a 100644 --- a/py-polars/tests/unit/lazyframe/test_lazyframe.py +++ b/py-polars/tests/unit/lazyframe/test_lazyframe.py @@ -13,13 +13,13 @@ import polars as pl import polars.selectors as cs from polars import lit, when -from polars.datatypes import FLOAT_DTYPES from polars.exceptions import ( InvalidOperationError, PerformanceWarning, PolarsInefficientMapWarning, ) from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES if TYPE_CHECKING: from _pytest.capture import CaptureFixture @@ -527,13 +527,13 @@ def test_floor() -> None: (1.0e20, 2, 100000000000000000000.0), ], ) -def test_round(n: float, ndigits: int, expected: float) -> None: - for float_dtype in FLOAT_DTYPES: - ldf = pl.LazyFrame({"value": [n]}, schema_overrides={"value": float_dtype}) - assert_series_equal( - ldf.select(pl.col("value").round(decimals=ndigits)).collect().to_series(), - pl.Series("value", [expected], dtype=float_dtype), - ) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_round(n: float, ndigits: int, expected: float, dtype: pl.DataType) -> None: + ldf = pl.LazyFrame({"value": [n]}, schema_overrides={"value": dtype}) + assert_series_equal( + ldf.select(pl.col("value").round(decimals=ndigits)).collect().to_series(), + pl.Series("value", [expected], dtype=dtype), + ) def test_dot() -> None: diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py index 7f1a4711c8ac..c42c6f2ea666 100644 --- a/py-polars/tests/unit/ml/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -235,7 +235,7 @@ def test_to_dataset_half_precision(df: pl.DataFrame) -> None: [ ("x", None), ("x", ["y", "z"]), - (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + (cs.integer(), ~cs.integer()), ], ) def test_to_torch_labelled_dataset(label: Any, features: Any, df: pl.DataFrame) -> None: diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 6ac9c65fc729..0abe473fa970 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -19,9 +19,9 @@ UInt32, UInt64, ) -from polars.datatypes import FLOAT_DTYPES, INTEGER_DTYPES from polars.exceptions import ColumnNotFoundError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES def test_sqrt_neg_inf() -> None: @@ -266,24 +266,24 @@ def test_arithmetic_null_count() -> None: operator.sub, ], ) -def test_operator_arithmetic_with_nulls(op: Any) -> None: - for dtype in FLOAT_DTYPES | INTEGER_DTYPES: - df = pl.DataFrame({"n": [2, 3]}, schema={"n": dtype}) - s = df.to_series() - - df_expected = pl.DataFrame({"n": [None, None]}, schema={"n": dtype}) - s_expected = df_expected.to_series() - - # validate expr, frame, and series behaviour with null value arithmetic - op_name = op.__name__ - for null_expr in (None, pl.lit(None)): - assert_frame_equal(df_expected, df.select(op(pl.col("n"), null_expr))) - assert_frame_equal( - df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr)) - ) +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_operator_arithmetic_with_nulls(op: Any, dtype: pl.DataType) -> None: + df = pl.DataFrame({"n": [2, 3]}, schema={"n": dtype}) + s = df.to_series() + + df_expected = pl.DataFrame({"n": [None, None]}, schema={"n": dtype}) + s_expected = df_expected.to_series() + + # validate expr, frame, and series behaviour with null value arithmetic + op_name = op.__name__ + for null_expr in (None, pl.lit(None)): + assert_frame_equal(df_expected, df.select(op(pl.col("n"), null_expr))) + assert_frame_equal( + df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr)) + ) - assert_frame_equal(df_expected, op(df, None)) - assert_series_equal(s_expected, op(s, None)) + assert_frame_equal(df_expected, op(df, None)) + assert_series_equal(s_expected, op(s, None)) @pytest.mark.parametrize( @@ -613,14 +613,14 @@ def test_literal_subtract_schema_13284() -> None: ).collect_schema() == OrderedDict([("a", pl.UInt8), ("len", pl.UInt32)]) -def test_int_operator_stability() -> None: - for dt in pl.datatypes.INTEGER_DTYPES: - s = pl.Series(values=[10], dtype=dt) - assert pl.select(pl.lit(s) // 2).dtypes == [dt] - assert pl.select(pl.lit(s) + 2).dtypes == [dt] - assert pl.select(pl.lit(s) - 2).dtypes == [dt] - assert pl.select(pl.lit(s) * 2).dtypes == [dt] - assert pl.select(pl.lit(s) / 2).dtypes == [pl.Float64] +@pytest.mark.parametrize("dtype", INTEGER_DTYPES) +def test_int_operator_stability(dtype: pl.DataType) -> None: + s = pl.Series(values=[10], dtype=dtype) + assert pl.select(pl.lit(s) // 2).dtypes == [dtype] + assert pl.select(pl.lit(s) + 2).dtypes == [dtype] + assert pl.select(pl.lit(s) - 2).dtypes == [dtype] + assert pl.select(pl.lit(s) * 2).dtypes == [dtype] + assert pl.select(pl.lit(s) / 2).dtypes == [pl.Float64] def test_duration_division_schema() -> None: diff --git a/py-polars/tests/unit/operations/namespaces/test_meta.py b/py-polars/tests/unit/operations/namespaces/test_meta.py index 1a707db99bf6..2772088cfd81 100644 --- a/py-polars/tests/unit/operations/namespaces/test_meta.py +++ b/py-polars/tests/unit/operations/namespaces/test_meta.py @@ -7,6 +7,7 @@ import polars as pl import polars.selectors as cs from polars.exceptions import ComputeError +from tests.unit.conftest import NUMERIC_DTYPES if TYPE_CHECKING: from pathlib import Path @@ -88,7 +89,7 @@ def test_is_column() -> None: # columns (pl.col("foo"), True), (pl.col("foo", "bar"), True), - (pl.col(pl.NUMERIC_DTYPES), True), + (pl.col(NUMERIC_DTYPES), True), # column expressions (pl.col("foo") + 100, False), (pl.col("foo").floordiv(10), False), diff --git a/py-polars/tests/unit/series/buffers/test_from_buffer.py b/py-polars/tests/unit/series/buffers/test_from_buffer.py index 99588293dd42..7bdb6e45c0b9 100644 --- a/py-polars/tests/unit/series/buffers/test_from_buffer.py +++ b/py-polars/tests/unit/series/buffers/test_from_buffer.py @@ -8,11 +8,12 @@ import polars as pl from polars.testing import assert_series_equal from polars.testing.parametric import series +from tests.unit.conftest import NUMERIC_DTYPES @given( s=series( - allowed_dtypes=(pl.INTEGER_DTYPES | pl.FLOAT_DTYPES | {pl.Boolean}), + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], allow_chunks=False, allow_null=False, ) diff --git a/py-polars/tests/unit/series/buffers/test_from_buffers.py b/py-polars/tests/unit/series/buffers/test_from_buffers.py index 14659d1f7b3f..092864145071 100644 --- a/py-polars/tests/unit/series/buffers/test_from_buffers.py +++ b/py-polars/tests/unit/series/buffers/test_from_buffers.py @@ -10,20 +10,7 @@ from polars.exceptions import PanicException from polars.testing import assert_series_equal from polars.testing.parametric import series - -if TYPE_CHECKING: - from polars.type_aliases import PolarsDataType - -# TODO: Define data type groups centrally somewhere in the test suite -DATETIME_DTYPES: set[PolarsDataType] = { - pl.Datetime, - pl.Datetime("ms"), - pl.Datetime("us"), - pl.Datetime("ns"), -} -TEMPORAL_DTYPES: set[PolarsDataType] = ( - {pl.Date, pl.Time} | pl.DURATION_DTYPES | DATETIME_DTYPES -) +from tests.unit.conftest import NUMERIC_DTYPES if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -34,7 +21,7 @@ @given( s=series( - allowed_dtypes=(pl.INTEGER_DTYPES | pl.FLOAT_DTYPES | {pl.Boolean}), + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], allow_chunks=False, ) ) @@ -46,7 +33,7 @@ def test_series_from_buffers_numeric_with_validity(s: pl.Series) -> None: @given( s=series( - allowed_dtypes=(pl.INTEGER_DTYPES | pl.FLOAT_DTYPES | {pl.Boolean}), + allowed_dtypes=[*NUMERIC_DTYPES, pl.Boolean], allow_chunks=False, allow_null=False, ) @@ -56,7 +43,12 @@ def test_series_from_buffers_numeric(s: pl.Series) -> None: assert_series_equal(s, result) -@given(s=series(allowed_dtypes=TEMPORAL_DTYPES, allow_chunks=False)) +@given( + s=series( + allowed_dtypes=[pl.Date, pl.Time, pl.Datetime, pl.Duration], + allow_chunks=False, + ) +) def test_series_from_buffers_temporal_with_validity(s: pl.Series) -> None: validity = s.is_not_null() physical = pl.Int32 if s.dtype == pl.Date else pl.Int64 diff --git a/py-polars/tests/unit/series/buffers/test_get_buffer_info.py b/py-polars/tests/unit/series/buffers/test_get_buffer_info.py index 738186e38fd4..71cab47d95a8 100644 --- a/py-polars/tests/unit/series/buffers/test_get_buffer_info.py +++ b/py-polars/tests/unit/series/buffers/test_get_buffer_info.py @@ -2,12 +2,13 @@ import polars as pl from polars.exceptions import ComputeError +from tests.unit.conftest import NUMERIC_DTYPES -def test_get_buffer_info_numeric() -> None: - for dtype in list(pl.FLOAT_DTYPES) + list(pl.INTEGER_DTYPES): - s = pl.Series([1, 2, 3], dtype=dtype) - assert s._get_buffer_info()[0] > 0 +@pytest.mark.parametrize("dtype", NUMERIC_DTYPES) +def test_get_buffer_info_numeric(dtype: pl.DataType) -> None: + s = pl.Series([1, 2, 3], dtype=dtype) + assert s._get_buffer_info()[0] > 0 def test_get_buffer_info_bool() -> None: diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 16e4cd9d5715..de36443fc3a5 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -9,6 +9,7 @@ import polars as pl from polars.exceptions import DuplicateError from polars.testing import assert_frame_equal +from tests.unit.conftest import INTEGER_DTYPES if TYPE_CHECKING: from pathlib import Path @@ -321,7 +322,7 @@ def test_streaming_group_by_all_numeric_types_stability_8570() -> None: dfc = dfa.join(dfb, how="cross") for keys in [["x", "y"], "z"]: - for dtype in [pl.Boolean, *pl.INTEGER_DTYPES]: + for dtype in [*INTEGER_DTYPES, pl.Boolean]: # the alias checks if the schema is correctly handled dfd = ( dfc.lazy() diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 2ea54868da01..40fde7a52df0 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -10,38 +10,36 @@ from polars import datatypes from polars.datatypes import ( DTYPE_TEMPORAL_UNITS, - DataTypeGroup, Field, Int64, List, Struct, py_type_to_dtype, ) +from polars.datatypes.group import DataTypeGroup +from tests.unit.conftest import DATETIME_DTYPES, NUMERIC_DTYPES if TYPE_CHECKING: - from polars.datatypes import DataTypeClass + from polars.datatypes.classes import DataTypeClass from polars.type_aliases import PolarsDataType -SIMPLE_DTYPES: list[DataTypeClass] = list( - pl.INTEGER_DTYPES # type: ignore[arg-type] - | pl.FLOAT_DTYPES - | { - pl.Boolean, - pl.String, - pl.Binary, - pl.Time, - pl.Date, - pl.Object, - pl.Null, - pl.Unknown, - } -) +SIMPLE_DTYPES: list[DataTypeClass] = [ + *[dt.base_type() for dt in NUMERIC_DTYPES], + pl.Boolean, + pl.String, + pl.Binary, + pl.Time, + pl.Date, + pl.Object, + pl.Null, + pl.Unknown, +] -def test_simple_dtype_init_takes_no_args() -> None: - for dtype in SIMPLE_DTYPES: - with pytest.raises(TypeError): - dtype(10) +@pytest.mark.parametrize("dtype", SIMPLE_DTYPES) +def test_simple_dtype_init_takes_no_args(dtype: DataTypeClass) -> None: + with pytest.raises(TypeError): + dtype(10) def test_simple_dtype_init_returns_instance() -> None: @@ -55,7 +53,7 @@ def test_complex_dtype_init_returns_instance() -> None: assert dtype.time_unit == "us" -def test_dtype_temporal_units() -> None: +def test_dtype_time_units() -> None: # check (in)equality behaviour of temporal types that take units for time_unit in DTYPE_TEMPORAL_UNITS: assert pl.Datetime == pl.Datetime(time_unit) @@ -85,7 +83,7 @@ def test_dtype_base_type() -> None: pl.Struct([pl.Field("a", pl.Int64), pl.Field("b", pl.Boolean)]).base_type() is pl.Struct ) - for dtype in pl.DATETIME_DTYPES: + for dtype in DATETIME_DTYPES: assert dtype.base_type() is pl.Datetime diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 49eecad1122e..ebd3ec4e73ff 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -21,6 +21,7 @@ SchemaFieldNotFoundError, StructFieldNotFoundError, ) +from tests.unit.conftest import TEMPORAL_DTYPES if TYPE_CHECKING: from polars.type_aliases import ConcatMethod @@ -78,15 +79,17 @@ def test_error_on_invalid_by_in_asof_join() -> None: df1.join_asof(df2, on="b", by=["a", "c"]) -def test_error_on_invalid_series_init() -> None: - for dtype in pl.TEMPORAL_DTYPES: - py_type = dtype_to_py_type(dtype) - with pytest.raises( - TypeError, - match=f"'float' object cannot be interpreted as a {py_type.__name__!r}", - ): - pl.Series([1.5, 2.0, 3.75], dtype=dtype) +@pytest.mark.parametrize("dtype", TEMPORAL_DTYPES) +def test_error_on_invalid_series_init(dtype: pl.DataType) -> None: + py_type = dtype_to_py_type(dtype) + with pytest.raises( + TypeError, + match=f"'float' object cannot be interpreted as a {py_type.__name__!r}", + ): + pl.Series([1.5, 2.0, 3.75], dtype=dtype) + +def test_error_on_invalid_series_init2() -> None: with pytest.raises(TypeError, match="unexpected value"): pl.Series([1.5, 2.0, 3.75], dtype=pl.Int32) diff --git a/py-polars/tests/unit/test_expansion.py b/py-polars/tests/unit/test_expansion.py index 78394075e7bb..795d5511cc50 100644 --- a/py-polars/tests/unit/test_expansion.py +++ b/py-polars/tests/unit/test_expansion.py @@ -5,8 +5,8 @@ import pytest import polars as pl -from polars import NUMERIC_DTYPES from polars.testing import assert_frame_equal +from tests.unit.conftest import NUMERIC_DTYPES def test_regex_exclude() -> None: diff --git a/py-polars/tests/unit/test_init.py b/py-polars/tests/unit/test_init.py index 1813065c8f9e..cf3dda945bb2 100644 --- a/py-polars/tests/unit/test_init.py +++ b/py-polars/tests/unit/test_init.py @@ -20,3 +20,10 @@ def test_init_exceptions_deprecated() -> None: msg = "nope" with pytest.raises(ComputeError, match=msg): raise exc(msg) + + +def test_dtype_groups_deprecated() -> None: + with pytest.deprecated_call(match="`INTEGER_DTYPES` is deprecated."): + dtypes = pl.INTEGER_DTYPES + + assert pl.Int8 in dtypes diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index 347555bac6ad..6657dd051d0a 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -12,6 +12,7 @@ from polars.selectors import expand_selector, is_selector from polars.testing import assert_frame_equal from polars.type_aliases import SelectorType +from tests.unit.conftest import INTEGER_DTYPES, TEMPORAL_DTYPES if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -107,14 +108,14 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None: } ) assert df.select( - ~cs.by_dtype(pl.INTEGER_DTYPES, pl.TEMPORAL_DTYPES) - ).schema == OrderedDict( + ~cs.by_dtype(*INTEGER_DTYPES, *TEMPORAL_DTYPES) + ).schema == pl.Schema( { - "cde": pl.Float64, - "def": pl.Float32, - "eee": pl.Boolean, - "fgg": pl.Boolean, - "qqR": pl.String, + "cde": pl.Float64(), + "def": pl.Float32(), + "eee": pl.Boolean(), + "fgg": pl.Boolean(), + "qqR": pl.String(), } ) assert df.select(cs.by_dtype()).schema == {}