From fb01aa4236b4cc270b03d13961a1ae64a702408d Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 26 Jun 2024 10:43:00 +0200 Subject: [PATCH] refactor(python): Refactor parsing of data type inputs to Polars data types (#17164) --- .../polars/_utils/construction/dataframe.py | 11 +- .../polars/_utils/construction/series.py | 15 +- py-polars/polars/datatypes/__init__.py | 6 +- py-polars/polars/datatypes/_parse.py | 177 +++++++++++++++ py-polars/polars/datatypes/classes.py | 6 +- py-polars/polars/datatypes/convert.py | 154 +------------ py-polars/polars/expr/expr.py | 10 +- py-polars/polars/expr/string.py | 4 +- py-polars/polars/io/csv/batched_reader.py | 4 +- py-polars/polars/io/csv/functions.py | 7 +- py-polars/polars/io/database/_inference.py | 4 +- py-polars/polars/lazyframe/frame.py | 4 +- py-polars/polars/series/series.py | 15 +- .../unit/constructors/test_constructors.py | 209 +++++++++--------- .../tests/unit/constructors/test_dataframe.py | 2 +- py-polars/tests/unit/datatypes/test_parse.py | 138 ++++++++++++ py-polars/tests/unit/io/test_csv.py | 4 - py-polars/tests/unit/test_datatypes.py | 6 +- py-polars/tests/unit/test_errors.py | 4 +- 19 files changed, 473 insertions(+), 307 deletions(-) create mode 100644 py-polars/polars/datatypes/_parse.py create mode 100644 py-polars/tests/unit/datatypes/test_parse.py diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index 7f6a41d2fe97d..7f15c38cb90de 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -42,7 +42,8 @@ Struct, Unknown, is_polars_dtype, - py_type_to_dtype, + parse_into_dtype, + try_parse_into_dtype, ) from polars.dependencies import ( _NUMPY_AVAILABLE, @@ -192,7 +193,7 @@ def _normalize_dtype(dtype: Any) -> PolarsDataType: if is_polars_dtype(dtype, include_unknown=True): return dtype else: - return py_type_to_dtype(dtype) + return parse_into_dtype(dtype) def _parse_schema_overrides( schema_overrides: SchemaDict | None = None, @@ -649,14 +650,14 @@ def _sequence_of_tuple_to_pydf( orient: Orientation | None, infer_schema_length: int | None, ) -> PyDataFrame: - # infer additional meta information if named tuple + # infer additional meta information if namedtuple if is_namedtuple(first_element.__class__): if schema is None: schema = first_element._fields # type: ignore[attr-defined] annotations = getattr(first_element, "__annotations__", None) if annotations and len(annotations) == len(schema): schema = [ - (name, py_type_to_dtype(tp, raise_unmatched=False)) + (name, try_parse_into_dtype(tp)) for name, tp in first_element.__annotations__.items() ] if orient is None: @@ -896,7 +897,7 @@ def _establish_dataclass_or_model_schema( else: column_names = [] overrides = { - col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown) + col: (try_parse_into_dtype(tp) or Unknown) for col, tp in try_get_type_hints(first_element.__class__).items() if ((col in model_fields) if model_fields else (col != "__slots__")) } diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py index 660147574465d..a045e34cdd730 100644 --- a/py-polars/polars/_utils/construction/series.py +++ b/py-polars/polars/_utils/construction/series.py @@ -43,7 +43,8 @@ dtype_to_py_type, is_polars_dtype, numpy_char_code_to_dtype, - py_type_to_dtype, + parse_into_dtype, + try_parse_into_dtype, ) from polars.datatypes.constructor import ( numpy_type_to_constructor, @@ -114,7 +115,7 @@ def sequence_to_pyseries( # * if the values are ISO-8601 strings, init then convert via strptime. # * if the values are floats/other dtypes, this is an error. if dtype in py_temporal_types and isinstance(value, int): - dtype = py_type_to_dtype(dtype) # construct from integer + dtype = parse_into_dtype(dtype) # construct from integer elif ( dtype in pl_temporal_types or type(dtype) in pl_temporal_types ) and not isinstance(value, int): @@ -167,15 +168,11 @@ def sequence_to_pyseries( # temporal branch if python_dtype in py_temporal_types: if dtype is None: - dtype = py_type_to_dtype(python_dtype) # construct from integer + dtype = parse_into_dtype(python_dtype) # construct from integer elif dtype in py_temporal_types: - dtype = py_type_to_dtype(dtype) + dtype = parse_into_dtype(dtype) - values_dtype = ( - None - if value is None - else py_type_to_dtype(type(value), raise_unmatched=False) - ) + values_dtype = None if value is None else try_parse_into_dtype(type(value)) if values_dtype is not None and values_dtype.is_float(): msg = f"'float' object cannot be interpreted as a {python_dtype.__name__!r}" raise TypeError( diff --git a/py-polars/polars/datatypes/__init__.py b/py-polars/polars/datatypes/__init__.py index fc36c47de77bf..38188d3f767b0 100644 --- a/py-polars/polars/datatypes/__init__.py +++ b/py-polars/polars/datatypes/__init__.py @@ -1,3 +1,4 @@ +from polars.datatypes._parse import parse_into_dtype, try_parse_into_dtype from polars.datatypes.classes import ( Array, Binary, @@ -49,7 +50,6 @@ maybe_cast, numpy_char_code_to_dtype, py_type_to_arrow_type, - py_type_to_dtype, supported_numpy_char_code, unpack_dtypes, ) @@ -103,7 +103,9 @@ "maybe_cast", "numpy_char_code_to_dtype", "py_type_to_arrow_type", - "py_type_to_dtype", "supported_numpy_char_code", "unpack_dtypes", + # _parse + "parse_into_dtype", + "try_parse_into_dtype", ] diff --git a/py-polars/polars/datatypes/_parse.py b/py-polars/polars/datatypes/_parse.py new file mode 100644 index 0000000000000..84cf3049c1dcc --- /dev/null +++ b/py-polars/polars/datatypes/_parse.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import functools +import re +import sys +from datetime import date, datetime, time, timedelta +from decimal import Decimal as PyDecimal +from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args + +from polars.datatypes.classes import ( + Binary, + Boolean, + Date, + Datetime, + Decimal, + Duration, + Float64, + Int64, + List, + Null, + Object, + String, + Time, +) +from polars.datatypes.convert import is_polars_dtype + +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict + + +UnionTypeOld = type(Union[int, str]) +if sys.version_info >= (3, 10): + from types import NoneType, UnionType +else: # pragma: no cover + # Define equivalent for older Python versions + NoneType = type(None) + UnionType = UnionTypeOld + + +def parse_into_dtype(input: Any) -> PolarsDataType: + """ + Parse an input into a Polars data type. + + Raises + ------ + TypeError + If the input cannot be parsed into a Polars data type. + """ + if is_polars_dtype(input): + return input + elif isinstance(input, ForwardRef): + return _parse_forward_ref_into_dtype(input) + elif isinstance(input, (UnionType, UnionTypeOld)): + return _parse_union_type_into_dtype(input) + else: + return parse_py_type_into_dtype(input) + + +def try_parse_into_dtype(input: Any) -> PolarsDataType | None: + """Try parsing an input into a Polars data type, returning None on failure.""" + try: + return parse_into_dtype(input) + except TypeError: + return None + + +@functools.lru_cache(16) +def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsDataType: + """Convert Python data type to Polars data type.""" + if input is int: + return Int64() + elif input is float: + return Float64() + elif input is str: + return String() + elif input is bool: + return Boolean() + elif input is date: + return Date() + elif input is datetime: + return Datetime("us") + elif input is timedelta: + return Duration + elif input is time: + return Time() + elif input is PyDecimal: + return Decimal + elif input is bytes: + return Binary() + elif input is object: + return Object() + elif input is NoneType: + return Null() + elif input is list or input is tuple: + return List + + elif hasattr(input, "__origin__") and hasattr(input, "__args__"): + return _parse_generic_into_dtype(input) + + else: + _raise_on_invalid_dtype(input) + + +def _parse_generic_into_dtype(input: Any) -> PolarsDataType: + """Parse a generic type into a Polars data type.""" + base_type = input.__origin__ + if base_type not in (tuple, list): + _raise_on_invalid_dtype(input) + + inner_types = input.__args__ + inner_type = inner_types[0] + if len(inner_types) > 1: + all_equal = all(t in (inner_type, ...) for t in inner_types) + if not all_equal: + _raise_on_invalid_dtype(input) + + inner_type = inner_types[0] + inner_dtype = parse_py_type_into_dtype(inner_type) + return List(inner_dtype) + + +PY_TYPE_STR_TO_DTYPE: SchemaDict = { + "int": Int64(), + "float": Float64(), + "bool": Boolean(), + "str": String(), + "bytes": Binary(), + "date": Date(), + "time": Time(), + "datetime": Datetime("us"), + "object": Object(), + "NoneType": Null(), + "timedelta": Duration, + "Decimal": Decimal, + "list": List, + "tuple": List, +} + + +def _parse_forward_ref_into_dtype(input: ForwardRef) -> PolarsDataType: + """Parse a ForwardRef into a Polars data type.""" + annotation = input.__forward_arg__ + + # Strip "optional" designation - Polars data types are always nullable + formatted = re.sub(r"(^None \|)|(\| None$)", "", annotation).strip() + + try: + return PY_TYPE_STR_TO_DTYPE[formatted] + except KeyError: + _raise_on_invalid_dtype(input) + + +def _parse_union_type_into_dtype(input: Any) -> PolarsDataType: + """ + Parse a union of types into a Polars data type. + + Unions of multiple non-null types (e.g. `int | float`) are not supported. + + Parameters + ---------- + input + A union type, e.g. `str | None` (new syntax) or `Union[str, None]` (old syntax). + """ + # Strip "optional" designation - Polars data types are always nullable + inner_types = [tp for tp in get_args(input) if tp is not NoneType] + + if len(inner_types) != 1: + _raise_on_invalid_dtype(input) + + input = inner_types[0] + return parse_into_dtype(input) + + +def _raise_on_invalid_dtype(input: Any) -> NoReturn: + """Raise an informative error if the input could not be parsed.""" + msg = f"cannot parse input of type {type(input).__name__!r} into Polars data type: {input!r}" + raise TypeError(msg) from None diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index cf2be33f2cbc2..6ab3630ade012 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -604,7 +604,7 @@ class List(NestedType): inner: PolarsDataType def __init__(self, inner: PolarsDataType | PythonDataType): - self.inner = polars.datatypes.py_type_to_dtype(inner) + self.inner = polars.datatypes.parse_into_dtype(inner) def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] # This equality check allows comparison of type classes and type instances. @@ -675,7 +675,7 @@ def __init__( msg = "Array constructor is missing the required argument `shape`" raise TypeError(msg) - inner_parsed = polars.datatypes.py_type_to_dtype(inner) + inner_parsed = polars.datatypes.parse_into_dtype(inner) inner_shape = inner_parsed.shape if isinstance(inner_parsed, Array) else () if isinstance(shape, int): @@ -754,7 +754,7 @@ class Field: def __init__(self, name: str, dtype: PolarsDataType): self.name = name - self.dtype = polars.datatypes.py_type_to_dtype(dtype) + self.dtype = polars.datatypes.parse_into_dtype(dtype) def __eq__(self, other: Field) -> bool: # type: ignore[override] return (self.name == other.name) & (self.dtype == other.dtype) diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 8604cc2365c4f..5ab2af72e1b64 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -6,18 +6,9 @@ import sys from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal -from typing import ( - TYPE_CHECKING, - Any, - Collection, - ForwardRef, - Optional, - Union, - get_args, - overload, -) +from typing import TYPE_CHECKING, Any, Collection, Optional, Union -from polars.datatypes import ( +from polars.datatypes.classes import ( Array, Binary, Boolean, @@ -63,97 +54,24 @@ UnionType = type(Union[int, float]) if TYPE_CHECKING: - from typing import Literal - - from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict, TimeUnit + from polars.type_aliases import PolarsDataType, PythonDataType, TimeUnit if sys.version_info >= (3, 10): from typing import TypeGuard else: from typing_extensions import TypeGuard -PY_STR_TO_DTYPE: SchemaDict = { - "float": Float64, - "int": Int64, - "str": String, - "bool": Boolean, - "date": Date, - "datetime": Datetime("us"), - "timedelta": Duration("us"), - "time": Time, - "list": List, - "tuple": List, - "Decimal": Decimal, - "bytes": Binary, - "object": Object, - "NoneType": Null, -} - - -@functools.lru_cache(16) -def _map_py_type_to_dtype( - python_dtype: PythonDataType | type[object], -) -> PolarsDataType: - """Convert Python data type to Polars data type.""" - if python_dtype is float: - return Float64 - if python_dtype is int: - return Int64 - if python_dtype is str: - return String - if python_dtype is bool: - return Boolean - if issubclass(python_dtype, datetime): - # `datetime` is a subclass of `date`, - # so need to check `datetime` first - return Datetime("us") - if issubclass(python_dtype, date): - return Date - if python_dtype is timedelta: - return Duration - if python_dtype is time: - return Time - if python_dtype is list: - return List - if python_dtype is tuple: - return List - if python_dtype is PyDecimal: - return Decimal - if python_dtype is bytes: - return Binary - if python_dtype is object: - return Object - if python_dtype is None.__class__: - return Null - - # cover generic typing aliases, such as 'list[str]' - if hasattr(python_dtype, "__origin__") and hasattr(python_dtype, "__args__"): - base_type = python_dtype.__origin__ - if base_type is not None: - dtype = _map_py_type_to_dtype(base_type) - nested = python_dtype.__args__ - if len(nested) == 1: - nested = nested[0] - return ( - dtype if nested is None else dtype(_map_py_type_to_dtype(nested)) # type: ignore[operator] - ) - - msg = f"unrecognised Python type: {python_dtype!r}" - raise TypeError(msg) - def is_polars_dtype( dtype: Any, *, include_unknown: bool = False ) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" - try: - if dtype == Unknown: - # does not represent a realizable dtype, so ignore by default - return include_unknown - else: - return isinstance(dtype, (DataType, DataTypeClass)) - except TypeError: - return False + is_dtype = isinstance(dtype, (DataType, DataTypeClass)) + + if not include_unknown: + return is_dtype and dtype != Unknown + else: + return is_dtype def unpack_dtypes( @@ -342,60 +260,6 @@ def dtype_to_py_type(dtype: PolarsDataType) -> PythonDataType: raise NotImplementedError(msg) from None -@overload -def py_type_to_dtype( - data_type: Any, *, raise_unmatched: Literal[True] = ... -) -> PolarsDataType: ... - - -@overload -def py_type_to_dtype( - data_type: Any, *, raise_unmatched: Literal[False] -) -> PolarsDataType | None: ... - - -def py_type_to_dtype( - data_type: Any, *, raise_unmatched: bool = True, allow_strings: bool = False -) -> PolarsDataType | None: - """Convert a Python dtype (or type annotation) to a Polars dtype.""" - if isinstance(data_type, ForwardRef): - annotation = data_type.__forward_arg__ - data_type = ( - PY_STR_TO_DTYPE.get( - re.sub(r"(^None \|)|(\| None$)", "", annotation).strip(), data_type - ) - if isinstance(annotation, str) # type: ignore[redundant-expr] - else annotation - ) - elif type(data_type).__name__ == "InitVar": - data_type = data_type.type - - if is_polars_dtype(data_type): - return data_type - - elif isinstance(data_type, (OptionType, UnionType)): - # not exhaustive; handles the common "type | None" case, but - # should probably pick appropriate supertype when n_types > 1? - possible_types = [tp for tp in get_args(data_type) if tp is not NoneType] - if len(possible_types) == 1: - data_type = possible_types[0] - - elif allow_strings and isinstance(data_type, str): - data_type = DataTypeMappings.REPR_TO_DTYPE.get( - re.sub(r"^(?:dataclasses\.)?InitVar\[(.+)\]$", r"\1", data_type), - data_type, - ) - if is_polars_dtype(data_type): - return data_type - try: - return _map_py_type_to_dtype(data_type) - except (KeyError, TypeError): # pragma: no cover - if raise_unmatched: - msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})" - raise ValueError(msg) from None - return None - - def py_type_to_arrow_type(dtype: PythonDataType) -> pa.lib.DataType: """Convert a Python dtype to an Arrow dtype.""" try: diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 69141750b2803..e9b0adef34d5d 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -46,11 +46,7 @@ sphinx_accessor, warn_null_comparison, ) -from polars.datatypes import ( - Int64, - is_polars_dtype, - py_type_to_dtype, -) +from polars.datatypes import Int64, is_polars_dtype, parse_into_dtype from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np from polars.exceptions import CustomUFuncWarning, PolarsInefficientMapWarning @@ -1748,7 +1744,7 @@ def cast( │ 3.0 ┆ 6 │ └─────┴─────┘ """ - dtype = py_type_to_dtype(dtype) + dtype = parse_into_dtype(dtype) return self._from_pyexpr(self._pyexpr.cast(dtype, strict, wrap_numerical)) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Expr: @@ -4378,7 +4374,7 @@ def map_batches( """ if return_dtype is not None: - return_dtype = py_type_to_dtype(return_dtype) + return_dtype = parse_into_dtype(return_dtype) return self._from_pyexpr( self._pyexpr.map_batches( diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 57e6af7db6035..995406fedf330 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -12,7 +12,7 @@ from polars._utils.parse import parse_into_expression from polars._utils.various import find_stacklevel from polars._utils.wrap import wrap_expr -from polars.datatypes import Date, Datetime, Time, py_type_to_dtype +from polars.datatypes import Date, Datetime, Time, parse_into_dtype from polars.datatypes.constants import N_INFER_DEFAULT from polars.exceptions import ChronoFormatWarning @@ -1236,7 +1236,7 @@ def json_decode( └─────────────────────┴─────────────┘ """ if dtype is not None: - dtype = py_type_to_dtype(dtype) + dtype = parse_into_dtype(dtype) return wrap_expr(self._pyexpr.str_json_decode(dtype, infer_schema_length)) def json_path_match(self, json_path: IntoExprColumn) -> Expr: diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py index fc80e77424c01..baf4a6c6f9364 100644 --- a/py-polars/polars/io/csv/batched_reader.py +++ b/py-polars/polars/io/csv/batched_reader.py @@ -8,7 +8,7 @@ normalize_filepath, ) from polars._utils.wrap import wrap_df -from polars.datatypes import N_INFER_DEFAULT, py_type_to_dtype +from polars.datatypes import N_INFER_DEFAULT, parse_into_dtype from polars.io._utils import parse_columns_arg, parse_row_index_args from polars.io.csv._utils import _update_columns @@ -65,7 +65,7 @@ def __init__( if isinstance(schema_overrides, dict): dtype_list = [] for k, v in schema_overrides.items(): - dtype_list.append((k, py_type_to_dtype(v))) + dtype_list.append((k, parse_into_dtype(v))) elif isinstance(schema_overrides, Sequence): dtype_slice = schema_overrides else: diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index ee630a370f046..1c2634f2281eb 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -13,8 +13,7 @@ normalize_filepath, ) from polars._utils.wrap import wrap_df, wrap_ldf -from polars.datatypes import N_INFER_DEFAULT, String -from polars.datatypes.convert import py_type_to_dtype +from polars.datatypes import N_INFER_DEFAULT, String, parse_into_dtype from polars.io._utils import ( is_glob_pattern, parse_columns_arg, @@ -501,7 +500,7 @@ def _read_csv_impl( if isinstance(schema_overrides, dict): dtype_list = [] for k, v in schema_overrides.items(): - dtype_list.append((k, py_type_to_dtype(v))) + dtype_list.append((k, parse_into_dtype(v))) elif isinstance(schema_overrides, Sequence): dtype_slice = schema_overrides else: @@ -1218,7 +1217,7 @@ def _scan_csv_impl( if schema_overrides is not None: dtype_list = [] for k, v in schema_overrides.items(): - dtype_list.append((k, py_type_to_dtype(v))) + dtype_list.append((k, parse_into_dtype(v))) processed_null_values = _process_null_values(null_values) if isinstance(source, list): diff --git a/py-polars/polars/io/database/_inference.py b/py-polars/polars/io/database/_inference.py index 2a34eeba83140..601e5ddd6514c 100644 --- a/py-polars/polars/io/database/_inference.py +++ b/py-polars/polars/io/database/_inference.py @@ -28,7 +28,7 @@ UInt32, UInt64, ) -from polars.datatypes.convert import _map_py_type_to_dtype +from polars.datatypes._parse import parse_py_type_into_dtype from polars.datatypes.group import ( INTEGER_DTYPES, UNSIGNED_INTEGER_DTYPES, @@ -209,7 +209,7 @@ def _infer_dtype_from_cursor_description( if isclass(type_code): # python types, eg: int, float, str, etc with suppress(TypeError): - dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type] + dtype = parse_py_type_into_dtype(type_code) # type: ignore[arg-type] elif isinstance(type_code, str): # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 01e56a4d2adc0..ba13a7dcbc099 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -72,7 +72,7 @@ UInt64, Unknown, is_polars_dtype, - py_type_to_dtype, + parse_into_dtype, ) from polars.datatypes.group import DataTypeGroup from polars.dependencies import import_optional, subprocess @@ -2786,7 +2786,7 @@ def cast( ): c = by_dtype(c) # type: ignore[arg-type] - dtype = py_type_to_dtype(dtype) + dtype = parse_into_dtype(dtype) cast_map.update( {c: dtype} if isinstance(c, str) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 7aead22f05f77..a238dbce75c09 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -81,7 +81,7 @@ is_polars_dtype, maybe_cast, numpy_char_code_to_dtype, - py_type_to_dtype, + parse_into_dtype, supported_numpy_char_code, ) from polars.datatypes._utils import dtype_to_init_repr @@ -267,14 +267,7 @@ def __init__( if dtype == Unknown: dtype = None elif dtype is not None and not is_polars_dtype(dtype): - # Raise early error on invalid dtype - if not is_polars_dtype( - pl_dtype := py_type_to_dtype(dtype, raise_unmatched=False) - ): - msg = f"given dtype: {dtype!r} is not a valid Polars data type and cannot be converted into one" - raise ValueError(msg) - else: - dtype = pl_dtype + dtype = parse_into_dtype(dtype) # Handle case where values are passed as the first argument original_name: str | None = None @@ -3953,7 +3946,7 @@ def cast( ] """ # Do not dispatch cast as it is expensive and used in other functions. - dtype = py_type_to_dtype(dtype) + dtype = parse_into_dtype(dtype) return self._from_pyseries(self._s.cast(dtype, strict, wrap_numerical)) def to_physical(self) -> Series: @@ -5259,7 +5252,7 @@ def map_elements( if return_dtype is None: pl_return_dtype = None else: - pl_return_dtype = py_type_to_dtype(return_dtype) + pl_return_dtype = parse_into_dtype(return_dtype) warn_on_inefficient_map(function, columns=[self.name], map_target="series") return self._from_pyseries( diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index 54dcc71ca7769..bd49c0009b6c3 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -176,13 +176,13 @@ def test_init_dict() -> None: def test_error_string_dtypes() -> None: - with pytest.raises(ValueError, match="cannot infer dtype"): + with pytest.raises(TypeError, match="cannot parse input"): pl.DataFrame( data={"x": [1, 2], "y": [3, 4], "z": [5, 6]}, schema={"x": "i16", "y": "i32", "z": "f32"}, # type: ignore[dict-item] ) - with pytest.raises(ValueError, match="not a valid Polars data type"): + with pytest.raises(TypeError, match="cannot parse input"): pl.Series("n", [1, 2, 3], dtype="f32") # type: ignore[arg-type] @@ -326,118 +326,121 @@ class Test(NamedTuple): assert df.rows() == test_data -def test_init_structured_objects_nested() -> None: - for Foo, Bar, Baz in ( +@pytest.mark.parametrize( + ("foo", "bar", "baz"), + [ (_TestFooDC, _TestBarDC, _TestBazDC), (_TestFooPD, _TestBarPD, _TestBazPD), (_TestFooNT, _TestBarNT, _TestBazNT), - ): - data = [ - Foo( - x=100, - y=Bar( - a="hello", - b=800, - c=Baz(d=datetime(2023, 4, 12, 10, 30), e=-10.5, f="world"), + ], +) +def test_init_structured_objects_nested(foo: Any, bar: Any, baz: Any) -> None: + data = [ + foo( + x=100, + y=bar( + a="hello", + b=800, + c=baz(d=datetime(2023, 4, 12, 10, 30), e=-10.5, f="world"), + ), + ) + ] + df = pl.DataFrame(data) + # shape: (1, 2) + # ┌─────┬───────────────────────────────────┐ + # │ x ┆ y │ + # │ --- ┆ --- │ + # │ i64 ┆ struct[3] │ + # ╞═════╪═══════════════════════════════════╡ + # │ 100 ┆ {"hello",800,{2023-04-12 10:30:0… │ + # └─────┴───────────────────────────────────┘ + + assert df.schema == { + "x": pl.Int64, + "y": pl.Struct( + [ + pl.Field("a", pl.String), + pl.Field("b", pl.Int64), + pl.Field( + "c", + pl.Struct( + [ + pl.Field("d", pl.Datetime("us")), + pl.Field("e", pl.Float64), + pl.Field("f", pl.String), + ] + ), ), - ) - ] - df = pl.DataFrame(data) - # shape: (1, 2) - # ┌─────┬───────────────────────────────────┐ - # │ x ┆ y │ - # │ --- ┆ --- │ - # │ i64 ┆ struct[3] │ - # ╞═════╪═══════════════════════════════════╡ - # │ 100 ┆ {"hello",800,{2023-04-12 10:30:0… │ - # └─────┴───────────────────────────────────┘ + ] + ), + } + assert df.row(0) == ( + 100, + { + "a": "hello", + "b": 800, + "c": { + "d": datetime(2023, 4, 12, 10, 30), + "e": -10.5, + "f": "world", + }, + }, + ) - assert df.schema == { - "x": pl.Int64, - "y": pl.Struct( - [ - pl.Field("a", pl.String), - pl.Field("b", pl.Int64), - pl.Field( - "c", - pl.Struct( - [ - pl.Field("d", pl.Datetime("us")), - pl.Field("e", pl.Float64), - pl.Field("f", pl.String), - ] - ), + # validate nested schema override + override_struct_schema: dict[str, PolarsDataType] = { + "x": pl.Int16, + "y": pl.Struct( + [ + pl.Field("a", pl.String), + pl.Field("b", pl.Int32), + pl.Field( + name="c", + dtype=pl.Struct( + [ + pl.Field("d", pl.Datetime("ms")), + pl.Field("e", pl.Float32), + pl.Field("f", pl.String), + ] ), - ] - ), + ), + ] + ), + } + for schema, schema_overrides in ( + (None, override_struct_schema), + (override_struct_schema, None), + ): + df = ( + pl.DataFrame(data, schema=schema, schema_overrides=schema_overrides) + .unnest("y") + .unnest("c") + ) + # shape: (1, 6) + # ┌─────┬───────┬─────┬─────────────────────┬───────┬───────┐ + # │ x ┆ a ┆ b ┆ d ┆ e ┆ f │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ i16 ┆ str ┆ i32 ┆ datetime[ms] ┆ f32 ┆ str │ + # ╞═════╪═══════╪═════╪═════════════════════╪═══════╪═══════╡ + # │ 100 ┆ hello ┆ 800 ┆ 2023-04-12 10:30:00 ┆ -10.5 ┆ world │ + # └─────┴───────┴─────┴─────────────────────┴───────┴───────┘ + assert df.schema == { + "x": pl.Int16, + "a": pl.String, + "b": pl.Int32, + "d": pl.Datetime("ms"), + "e": pl.Float32, + "f": pl.String, } assert df.row(0) == ( 100, - { - "a": "hello", - "b": 800, - "c": { - "d": datetime(2023, 4, 12, 10, 30), - "e": -10.5, - "f": "world", - }, - }, + "hello", + 800, + datetime(2023, 4, 12, 10, 30), + -10.5, + "world", ) - # validate nested schema override - override_struct_schema: dict[str, PolarsDataType] = { - "x": pl.Int16, - "y": pl.Struct( - [ - pl.Field("a", pl.String), - pl.Field("b", pl.Int32), - pl.Field( - name="c", - dtype=pl.Struct( - [ - pl.Field("d", pl.Datetime("ms")), - pl.Field("e", pl.Float32), - pl.Field("f", pl.String), - ] - ), - ), - ] - ), - } - for schema, schema_overrides in ( - (None, override_struct_schema), - (override_struct_schema, None), - ): - df = ( - pl.DataFrame(data, schema=schema, schema_overrides=schema_overrides) - .unnest("y") - .unnest("c") - ) - # shape: (1, 6) - # ┌─────┬───────┬─────┬─────────────────────┬───────┬───────┐ - # │ x ┆ a ┆ b ┆ d ┆ e ┆ f │ - # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - # │ i16 ┆ str ┆ i32 ┆ datetime[ms] ┆ f32 ┆ str │ - # ╞═════╪═══════╪═════╪═════════════════════╪═══════╪═══════╡ - # │ 100 ┆ hello ┆ 800 ┆ 2023-04-12 10:30:00 ┆ -10.5 ┆ world │ - # └─────┴───────┴─────┴─────────────────────┴───────┴───────┘ - assert df.schema == { - "x": pl.Int16, - "a": pl.String, - "b": pl.Int32, - "d": pl.Datetime("ms"), - "e": pl.Float32, - "f": pl.String, - } - assert df.row(0) == ( - 100, - "hello", - 800, - datetime(2023, 4, 12, 10, 30), - -10.5, - "world", - ) - def test_dataclasses_initvar_typing() -> None: @dataclasses.dataclass diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py index f1ae8084e9dc6..b193c99f5011c 100644 --- a/py-polars/tests/unit/constructors/test_dataframe.py +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -181,7 +181,7 @@ def test_custom_schema() -> None: df = pl.DataFrame(schema=CustomSchema(bool=pl.Boolean, misc=pl.UInt8)) assert df.schema == OrderedDict([("bool", pl.Boolean), ("misc", pl.UInt8)]) - with pytest.raises(ValueError): + with pytest.raises(TypeError): pl.DataFrame(schema=CustomSchema(bool="boolean", misc="unsigned int")) diff --git a/py-polars/tests/unit/datatypes/test_parse.py b/py-polars/tests/unit/datatypes/test_parse.py new file mode 100644 index 0000000000000..b448633fb92b1 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_parse.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import ( + TYPE_CHECKING, + Any, + Dict, + ForwardRef, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import pytest + +import polars as pl +from polars.datatypes._parse import ( + _parse_forward_ref_into_dtype, + _parse_generic_into_dtype, + _parse_union_type_into_dtype, + parse_into_dtype, + parse_py_type_into_dtype, +) + +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType + + +def assert_dtype_equal(left: PolarsDataType, right: PolarsDataType) -> None: + assert left == right + assert type(left) == type(right) + assert hash(left) == hash(right) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (pl.Int8(), pl.Int8()), + (list, pl.List), + ], +) +def test_parse_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = parse_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (datetime, pl.Datetime("us")), + (date, pl.Date()), + (type(None), pl.Null()), + (object, pl.Object()), + ], +) +def test_parse_py_type_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = parse_py_type_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (List[int], pl.List(pl.Int64())), + (Tuple[str, ...], pl.List(pl.String())), + (Tuple[datetime, datetime], pl.List(pl.Datetime("us"))), + ], +) +def test_parse_generic_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_generic_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + "input", + [ + Dict[str, float], + Tuple[int, str], + Tuple[int, float, float], + ], +) +def test_parse_generic_into_dtype_invalid(input: Any) -> None: + with pytest.raises(TypeError): + _parse_generic_into_dtype(input) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (ForwardRef("date"), pl.Date()), + (ForwardRef("int | None"), pl.Int64()), + (ForwardRef("None | float"), pl.Float64()), + ], +) +def test_parse_forward_ref_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_forward_ref_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (Optional[int], pl.Int64()), + (Optional[pl.String], pl.String), + (Union[float, None], pl.Float64()), + ], +) +def test_parse_union_type_into_dtype(input: Any, expected: PolarsDataType) -> None: + result = _parse_union_type_into_dtype(input) + assert_dtype_equal(result, expected) + + +@pytest.mark.parametrize( + "input", + [ + Union[int, float], + Optional[Union[int, str]], + ], +) +def test_parse_union_type_into_dtype_invalid(input: Any) -> None: + with pytest.raises(TypeError): + _parse_union_type_into_dtype(input) + + +def test_parse_dtype_namedtuple_fields() -> None: + # Utilizes ForwardRef parsing + + class MyTuple(NamedTuple): + a: str + b: int + c: str | None = None + + schema = {c: parse_into_dtype(a) for c, a in MyTuple.__annotations__.items()} + + expected = pl.Schema({"a": pl.String(), "b": pl.Int64(), "c": pl.String()}) + assert schema == expected diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index a8decf3981186..8c61e1981b431 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1744,8 +1744,6 @@ def test_read_csv_comments_on_top_with_schema_11667() -> None: def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: - # The capsys fixture allows pytest to access stdout/stderr. See - # https://docs.pytest.org/en/7.1.x/how-to/capture-stdout-stderr.html df = pl.DataFrame( { "numbers": [1, 2, 3], @@ -1753,8 +1751,6 @@ def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: "dates": [date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)], } ) - - # pytest hijacks sys.stdout and changes its type, which causes mypy failure df.write_csv(sys.stdout) captured = capsys.readouterr() assert captured.out == ( diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py index 40fde7a52df00..89a8ff4079c6b 100644 --- a/py-polars/tests/unit/test_datatypes.py +++ b/py-polars/tests/unit/test_datatypes.py @@ -14,7 +14,7 @@ Int64, List, Struct, - py_type_to_dtype, + parse_into_dtype, ) from polars.datatypes.group import DataTypeGroup from tests.unit.conftest import DATETIME_DTYPES, NUMERIC_DTYPES @@ -66,8 +66,8 @@ def test_dtype_time_units() -> None: assert pl.Duration("ns") != pl.Duration("us") # check timeunit from pytype - assert py_type_to_dtype(datetime) == pl.Datetime("us") - assert py_type_to_dtype(timedelta) == pl.Duration + assert parse_into_dtype(datetime) == pl.Datetime("us") + assert parse_into_dtype(timedelta) == pl.Duration with pytest.raises(ValueError, match="invalid `time_unit`"): pl.Datetime("?") # type: ignore[arg-type] diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index dc89f0e43f2ed..ad8152b24a7f9 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -325,8 +325,8 @@ def test_datetime_time_add_err() -> None: def test_invalid_dtype() -> None: with pytest.raises( - ValueError, - match=r"given dtype: 'mayonnaise' is not a valid Polars data type and cannot be converted into one", + TypeError, + match="cannot parse input of type 'str' into Polars data type: 'mayonnaise'", ): pl.Series([1, 2], dtype="mayonnaise") # type: ignore[arg-type]