Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python): Refactor parsing of data type inputs to Polars data types #17164

Merged
merged 14 commits into from
Jun 26, 2024
11 changes: 6 additions & 5 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
Struct,
Unknown,
is_polars_dtype,
py_type_to_dtype,
parse_into_dtype,
try_parse_into_dtype,
)
from polars.dependencies import (
_NUMPY_AVAILABLE,
Expand Down Expand Up @@ -192,7 +193,7 @@ def _normalize_dtype(dtype: Any) -> PolarsDataType:
if is_polars_dtype(dtype, include_unknown=True):
return dtype
else:
return py_type_to_dtype(dtype)
return parse_into_dtype(dtype)

def _parse_schema_overrides(
schema_overrides: SchemaDict | None = None,
Expand Down Expand Up @@ -649,14 +650,14 @@ def _sequence_of_tuple_to_pydf(
orient: Orientation | None,
infer_schema_length: int | None,
) -> PyDataFrame:
# infer additional meta information if named tuple
# infer additional meta information if namedtuple
if is_namedtuple(first_element.__class__):
if schema is None:
schema = first_element._fields # type: ignore[attr-defined]
annotations = getattr(first_element, "__annotations__", None)
if annotations and len(annotations) == len(schema):
schema = [
(name, py_type_to_dtype(tp, raise_unmatched=False))
(name, try_parse_into_dtype(tp))
for name, tp in first_element.__annotations__.items()
]
if orient is None:
Expand Down Expand Up @@ -896,7 +897,7 @@ def _establish_dataclass_or_model_schema(
else:
column_names = []
overrides = {
col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown)
col: (try_parse_into_dtype(tp) or Unknown)
for col, tp in try_get_type_hints(first_element.__class__).items()
if ((col in model_fields) if model_fields else (col != "__slots__"))
}
Expand Down
15 changes: 6 additions & 9 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
dtype_to_py_type,
is_polars_dtype,
numpy_char_code_to_dtype,
py_type_to_dtype,
parse_into_dtype,
try_parse_into_dtype,
)
from polars.datatypes.constructor import (
numpy_type_to_constructor,
Expand Down Expand Up @@ -114,7 +115,7 @@ def sequence_to_pyseries(
# * if the values are ISO-8601 strings, init then convert via strptime.
# * if the values are floats/other dtypes, this is an error.
if dtype in py_temporal_types and isinstance(value, int):
dtype = py_type_to_dtype(dtype) # construct from integer
dtype = parse_into_dtype(dtype) # construct from integer
elif (
dtype in pl_temporal_types or type(dtype) in pl_temporal_types
) and not isinstance(value, int):
Expand Down Expand Up @@ -167,15 +168,11 @@ def sequence_to_pyseries(
# temporal branch
if python_dtype in py_temporal_types:
if dtype is None:
dtype = py_type_to_dtype(python_dtype) # construct from integer
dtype = parse_into_dtype(python_dtype) # construct from integer
elif dtype in py_temporal_types:
dtype = py_type_to_dtype(dtype)
dtype = parse_into_dtype(dtype)

values_dtype = (
None
if value is None
else py_type_to_dtype(type(value), raise_unmatched=False)
)
values_dtype = None if value is None else try_parse_into_dtype(type(value))
if values_dtype is not None and values_dtype.is_float():
msg = f"'float' object cannot be interpreted as a {python_dtype.__name__!r}"
raise TypeError(
Expand Down
6 changes: 4 additions & 2 deletions py-polars/polars/datatypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from polars.datatypes._parse import parse_into_dtype, try_parse_into_dtype
from polars.datatypes.classes import (
Array,
Binary,
Expand Down Expand Up @@ -49,7 +50,6 @@
maybe_cast,
numpy_char_code_to_dtype,
py_type_to_arrow_type,
py_type_to_dtype,
supported_numpy_char_code,
unpack_dtypes,
)
Expand Down Expand Up @@ -103,7 +103,9 @@
"maybe_cast",
"numpy_char_code_to_dtype",
"py_type_to_arrow_type",
"py_type_to_dtype",
"supported_numpy_char_code",
"unpack_dtypes",
# _parse
"parse_into_dtype",
"try_parse_into_dtype",
]
177 changes: 177 additions & 0 deletions py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from __future__ import annotations

import functools
import re
import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal as PyDecimal
from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args

from polars.datatypes.classes import (
Binary,
Boolean,
Date,
Datetime,
Decimal,
Duration,
Float64,
Int64,
List,
Null,
Object,
String,
Time,
)
from polars.datatypes.convert import is_polars_dtype

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict


UnionTypeOld = type(Union[int, str])
if sys.version_info >= (3, 10):
from types import NoneType, UnionType
else: # pragma: no cover
# Define equivalent for older Python versions
NoneType = type(None)
UnionType = UnionTypeOld


def parse_into_dtype(input: Any) -> PolarsDataType:
"""
Parse an input into a Polars data type.

Raises
------
TypeError
If the input cannot be parsed into a Polars data type.
"""
if is_polars_dtype(input):
return input
elif isinstance(input, ForwardRef):
return _parse_forward_ref_into_dtype(input)
elif isinstance(input, (UnionType, UnionTypeOld)):
return _parse_union_type_into_dtype(input)
else:
return parse_py_type_into_dtype(input)


def try_parse_into_dtype(input: Any) -> PolarsDataType | None:
"""Try parsing an input into a Polars data type, returning None on failure."""
try:
return parse_into_dtype(input)
except TypeError:
return None


@functools.lru_cache(16)
def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsDataType:
"""Convert Python data type to Polars data type."""
if input is int:
return Int64()
elif input is float:
return Float64()
elif input is str:
return String()
elif input is bool:
return Boolean()
elif input is date:
return Date()
elif input is datetime:
return Datetime("us")
elif input is timedelta:
return Duration
elif input is time:
return Time()
elif input is PyDecimal:
return Decimal
elif input is bytes:
return Binary()
elif input is object:
return Object()
elif input is NoneType:
return Null()
elif input is list or input is tuple:
return List

elif hasattr(input, "__origin__") and hasattr(input, "__args__"):
return _parse_generic_into_dtype(input)

else:
_raise_on_invalid_dtype(input)


def _parse_generic_into_dtype(input: Any) -> PolarsDataType:
"""Parse a generic type into a Polars data type."""
base_type = input.__origin__
if base_type not in (tuple, list):
_raise_on_invalid_dtype(input)

inner_types = input.__args__
inner_type = inner_types[0]
if len(inner_types) > 1:
all_equal = all(t in (inner_type, ...) for t in inner_types)
if not all_equal:
_raise_on_invalid_dtype(input)

inner_type = inner_types[0]
inner_dtype = parse_py_type_into_dtype(inner_type)
return List(inner_dtype)


PY_TYPE_STR_TO_DTYPE: SchemaDict = {
"int": Int64(),
"float": Float64(),
"bool": Boolean(),
"str": String(),
"bytes": Binary(),
"date": Date(),
"time": Time(),
"datetime": Datetime("us"),
"object": Object(),
"NoneType": Null(),
"timedelta": Duration,
"Decimal": Decimal,
"list": List,
"tuple": List,
}


def _parse_forward_ref_into_dtype(input: ForwardRef) -> PolarsDataType:
"""Parse a ForwardRef into a Polars data type."""
annotation = input.__forward_arg__

# Strip "optional" designation - Polars data types are always nullable
formatted = re.sub(r"(^None \|)|(\| None$)", "", annotation).strip()

try:
return PY_TYPE_STR_TO_DTYPE[formatted]
except KeyError:
_raise_on_invalid_dtype(input)


def _parse_union_type_into_dtype(input: Any) -> PolarsDataType:
"""
Parse a union of types into a Polars data type.

Unions of multiple non-null types (e.g. `int | float`) are not supported.

Parameters
----------
input
A union type, e.g. `str | None` (new syntax) or `Union[str, None]` (old syntax).
"""
# Strip "optional" designation - Polars data types are always nullable
inner_types = [tp for tp in get_args(input) if tp is not NoneType]

if len(inner_types) != 1:
_raise_on_invalid_dtype(input)

input = inner_types[0]
return parse_into_dtype(input)


def _raise_on_invalid_dtype(input: Any) -> NoReturn:
"""Raise an informative error if the input could not be parsed."""
msg = f"cannot parse input of type {type(input).__name__!r} into Polars data type: {input!r}"
raise TypeError(msg) from None
6 changes: 3 additions & 3 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ class List(NestedType):
inner: PolarsDataType

def __init__(self, inner: PolarsDataType | PythonDataType):
self.inner = polars.datatypes.py_type_to_dtype(inner)
self.inner = polars.datatypes.parse_into_dtype(inner)

def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
# This equality check allows comparison of type classes and type instances.
Expand Down Expand Up @@ -675,7 +675,7 @@ def __init__(
msg = "Array constructor is missing the required argument `shape`"
raise TypeError(msg)

inner_parsed = polars.datatypes.py_type_to_dtype(inner)
inner_parsed = polars.datatypes.parse_into_dtype(inner)
inner_shape = inner_parsed.shape if isinstance(inner_parsed, Array) else ()

if isinstance(shape, int):
Expand Down Expand Up @@ -754,7 +754,7 @@ class Field:

def __init__(self, name: str, dtype: PolarsDataType):
self.name = name
self.dtype = polars.datatypes.py_type_to_dtype(dtype)
self.dtype = polars.datatypes.parse_into_dtype(dtype)

def __eq__(self, other: Field) -> bool: # type: ignore[override]
return (self.name == other.name) & (self.dtype == other.dtype)
Expand Down
Loading