Skip to content

Commit

Permalink
refactor(python): Refactor parsing of data type inputs to Polars data…
Browse files Browse the repository at this point in the history
… types (pola-rs#17164)
  • Loading branch information
stinodego authored and alexander-beedie committed Jun 26, 2024
1 parent 2c5644e commit fb01aa4
Show file tree
Hide file tree
Showing 19 changed files with 473 additions and 307 deletions.
11 changes: 6 additions & 5 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
Struct,
Unknown,
is_polars_dtype,
py_type_to_dtype,
parse_into_dtype,
try_parse_into_dtype,
)
from polars.dependencies import (
_NUMPY_AVAILABLE,
Expand Down Expand Up @@ -192,7 +193,7 @@ def _normalize_dtype(dtype: Any) -> PolarsDataType:
if is_polars_dtype(dtype, include_unknown=True):
return dtype
else:
return py_type_to_dtype(dtype)
return parse_into_dtype(dtype)

def _parse_schema_overrides(
schema_overrides: SchemaDict | None = None,
Expand Down Expand Up @@ -649,14 +650,14 @@ def _sequence_of_tuple_to_pydf(
orient: Orientation | None,
infer_schema_length: int | None,
) -> PyDataFrame:
# infer additional meta information if named tuple
# infer additional meta information if namedtuple
if is_namedtuple(first_element.__class__):
if schema is None:
schema = first_element._fields # type: ignore[attr-defined]
annotations = getattr(first_element, "__annotations__", None)
if annotations and len(annotations) == len(schema):
schema = [
(name, py_type_to_dtype(tp, raise_unmatched=False))
(name, try_parse_into_dtype(tp))
for name, tp in first_element.__annotations__.items()
]
if orient is None:
Expand Down Expand Up @@ -896,7 +897,7 @@ def _establish_dataclass_or_model_schema(
else:
column_names = []
overrides = {
col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown)
col: (try_parse_into_dtype(tp) or Unknown)
for col, tp in try_get_type_hints(first_element.__class__).items()
if ((col in model_fields) if model_fields else (col != "__slots__"))
}
Expand Down
15 changes: 6 additions & 9 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
dtype_to_py_type,
is_polars_dtype,
numpy_char_code_to_dtype,
py_type_to_dtype,
parse_into_dtype,
try_parse_into_dtype,
)
from polars.datatypes.constructor import (
numpy_type_to_constructor,
Expand Down Expand Up @@ -114,7 +115,7 @@ def sequence_to_pyseries(
# * if the values are ISO-8601 strings, init then convert via strptime.
# * if the values are floats/other dtypes, this is an error.
if dtype in py_temporal_types and isinstance(value, int):
dtype = py_type_to_dtype(dtype) # construct from integer
dtype = parse_into_dtype(dtype) # construct from integer
elif (
dtype in pl_temporal_types or type(dtype) in pl_temporal_types
) and not isinstance(value, int):
Expand Down Expand Up @@ -167,15 +168,11 @@ def sequence_to_pyseries(
# temporal branch
if python_dtype in py_temporal_types:
if dtype is None:
dtype = py_type_to_dtype(python_dtype) # construct from integer
dtype = parse_into_dtype(python_dtype) # construct from integer
elif dtype in py_temporal_types:
dtype = py_type_to_dtype(dtype)
dtype = parse_into_dtype(dtype)

values_dtype = (
None
if value is None
else py_type_to_dtype(type(value), raise_unmatched=False)
)
values_dtype = None if value is None else try_parse_into_dtype(type(value))
if values_dtype is not None and values_dtype.is_float():
msg = f"'float' object cannot be interpreted as a {python_dtype.__name__!r}"
raise TypeError(
Expand Down
6 changes: 4 additions & 2 deletions py-polars/polars/datatypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from polars.datatypes._parse import parse_into_dtype, try_parse_into_dtype
from polars.datatypes.classes import (
Array,
Binary,
Expand Down Expand Up @@ -49,7 +50,6 @@
maybe_cast,
numpy_char_code_to_dtype,
py_type_to_arrow_type,
py_type_to_dtype,
supported_numpy_char_code,
unpack_dtypes,
)
Expand Down Expand Up @@ -103,7 +103,9 @@
"maybe_cast",
"numpy_char_code_to_dtype",
"py_type_to_arrow_type",
"py_type_to_dtype",
"supported_numpy_char_code",
"unpack_dtypes",
# _parse
"parse_into_dtype",
"try_parse_into_dtype",
]
177 changes: 177 additions & 0 deletions py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from __future__ import annotations

import functools
import re
import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal as PyDecimal
from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args

from polars.datatypes.classes import (
Binary,
Boolean,
Date,
Datetime,
Decimal,
Duration,
Float64,
Int64,
List,
Null,
Object,
String,
Time,
)
from polars.datatypes.convert import is_polars_dtype

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict


# Runtime class of `typing.Union[...]` objects, captured for `isinstance` checks.
UnionTypeOld = type(Union[int, str])

if sys.version_info < (3, 10):  # pragma: no cover
    # `types.NoneType` / `types.UnionType` cannot be imported here:
    # define equivalents so the rest of the module can use the same names.
    NoneType = type(None)
    UnionType = UnionTypeOld
else:
    from types import NoneType, UnionType


def parse_into_dtype(input: Any) -> PolarsDataType:
    """
    Convert a type-like input into a Polars data type.

    Inputs recognised by `is_polars_dtype` are returned unchanged. Forward
    references and union types are handed to their dedicated parsers; every
    other input is treated as a plain Python type.

    Raises
    ------
    TypeError
        If the input cannot be parsed into a Polars data type.
    """
    # Already a Polars data type: nothing to do.
    if is_polars_dtype(input):
        return input

    # String annotations wrapped in a `ForwardRef`, e.g. from unresolved type hints.
    if isinstance(input, ForwardRef):
        return _parse_forward_ref_into_dtype(input)

    # Both the `X | Y` and the `Union[X, Y]` spellings.
    if isinstance(input, (UnionType, UnionTypeOld)):
        return _parse_union_type_into_dtype(input)

    return parse_py_type_into_dtype(input)


def try_parse_into_dtype(input: Any) -> PolarsDataType | None:
    """
    Parse an input into a Polars data type, yielding None when that fails.

    Only the `TypeError` raised by `parse_into_dtype` is treated as a failed
    parse; any other exception propagates to the caller.
    """
    try:
        parsed = parse_into_dtype(input)
    except TypeError:
        return None
    else:
        return parsed


@functools.lru_cache(16)
def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsDataType:
    """
    Map a Python type onto the corresponding Polars data type.

    Results are memoized through `functools.lru_cache`. `timedelta`,
    `decimal.Decimal`, `list` and `tuple` map to the data type *class*
    (`Duration`, `Decimal`, `List`) rather than to an instance. Parameterized
    generics (objects exposing `__origin__` and `__args__`) are delegated to
    `_parse_generic_into_dtype`.

    Raises
    ------
    TypeError
        If the input does not correspond to a supported Python type.
    """
    # Numeric, string and boolean scalars.
    if input is int:
        return Int64()
    if input is float:
        return Float64()
    if input is str:
        return String()
    if input is bool:
        return Boolean()

    # Temporal types.
    if input is date:
        return Date()
    if input is datetime:
        return Datetime("us")
    if input is timedelta:
        return Duration
    if input is time:
        return Time()

    # Remaining scalar types.
    if input is PyDecimal:
        return Decimal
    if input is bytes:
        return Binary()
    if input is object:
        return Object()
    if input is NoneType:
        return Null()

    # Bare sequence types carry no inner type information.
    if input is list or input is tuple:
        return List

    # Parameterized generics such as `list[int]`.
    if hasattr(input, "__origin__") and hasattr(input, "__args__"):
        return _parse_generic_into_dtype(input)

    _raise_on_invalid_dtype(input)


def _parse_generic_into_dtype(input: Any) -> PolarsDataType:
    """
    Parse a parameterized generic type into a Polars data type.

    Only `list[...]` and `tuple[...]` generics whose arguments all refer to the
    same inner type are supported, e.g. `list[int]`, `tuple[str, ...]` or
    `tuple[int, int]`. The result is a `List` of the parsed inner type.

    Raises
    ------
    TypeError
        If the generic is not based on `list` / `tuple`, has no type arguments
        (e.g. `tuple[()]`), or mixes different inner types.
    """
    base_type = input.__origin__
    if base_type not in (tuple, list):
        _raise_on_invalid_dtype(input)

    inner_types = input.__args__
    # An empty argument tuple (e.g. `tuple[()]` on Python 3.11+) has no inner
    # type to parse. Raise the regular TypeError instead of leaking an
    # IndexError, which `try_parse_into_dtype` would not catch.
    if not inner_types:
        _raise_on_invalid_dtype(input)

    inner_type = inner_types[0]
    # All remaining arguments must equal the first one; `...` is accepted to
    # support the variadic form `tuple[T, ...]`.
    if not all(t in (inner_type, ...) for t in inner_types[1:]):
        _raise_on_invalid_dtype(input)

    inner_dtype = parse_py_type_into_dtype(inner_type)
    return List(inner_dtype)


# Lookup table from the *name* of a Python type, as it appears in a string
# annotation, to a Polars data type. `_parse_forward_ref_into_dtype` indexes
# this mapping with the text of a `ForwardRef` after stripping `None` unions.
PY_TYPE_STR_TO_DTYPE: SchemaDict = {
"int": Int64(),
"float": Float64(),
"bool": Boolean(),
"str": String(),
"bytes": Binary(),
"date": Date(),
"time": Time(),
"datetime": Datetime("us"),
"object": Object(),
"NoneType": Null(),
# The entries below map to data type classes rather than instances.
"timedelta": Duration,
"Decimal": Decimal,
"list": List,
"tuple": List,
}


def _parse_forward_ref_into_dtype(input: ForwardRef) -> PolarsDataType:
    """
    Parse a `ForwardRef` (string annotation) into a Polars data type.

    A leading `None |` or trailing `| None` is removed from the annotation
    text, and the remaining type name is looked up in `PY_TYPE_STR_TO_DTYPE`.

    Raises
    ------
    TypeError
        If the remaining type name is not a key of `PY_TYPE_STR_TO_DTYPE`.
    """
    raw_annotation = input.__forward_arg__

    # Strip "optional" designation - Polars data types are always nullable
    without_none = re.sub(r"(^None \|)|(\| None$)", "", raw_annotation)
    type_name = without_none.strip()

    try:
        dtype = PY_TYPE_STR_TO_DTYPE[type_name]
    except KeyError:
        _raise_on_invalid_dtype(input)
    return dtype


def _parse_union_type_into_dtype(input: Any) -> PolarsDataType:
    """
    Parse a union of types into a Polars data type.

    Unions of multiple non-null types (e.g. `int | float`) are not supported.

    Parameters
    ----------
    input
        A union type, e.g. `str | None` (new syntax) or `Union[str, None]` (old syntax).

    Raises
    ------
    TypeError
        If the union does not contain exactly one non-null member, or if that
        member cannot be parsed.
    """
    # Strip "optional" designation - Polars data types are always nullable
    non_null_members = [
        member for member in get_args(input) if member is not NoneType
    ]

    if len(non_null_members) == 1:
        (sole_member,) = non_null_members
        return parse_into_dtype(sole_member)

    _raise_on_invalid_dtype(input)


def _raise_on_invalid_dtype(input: Any) -> NoReturn:
"""Raise an informative error if the input could not be parsed."""
msg = f"cannot parse input of type {type(input).__name__!r} into Polars data type: {input!r}"
raise TypeError(msg) from None
6 changes: 3 additions & 3 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ class List(NestedType):
inner: PolarsDataType

def __init__(self, inner: PolarsDataType | PythonDataType):
self.inner = polars.datatypes.py_type_to_dtype(inner)
self.inner = polars.datatypes.parse_into_dtype(inner)

def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
# This equality check allows comparison of type classes and type instances.
Expand Down Expand Up @@ -675,7 +675,7 @@ def __init__(
msg = "Array constructor is missing the required argument `shape`"
raise TypeError(msg)

inner_parsed = polars.datatypes.py_type_to_dtype(inner)
inner_parsed = polars.datatypes.parse_into_dtype(inner)
inner_shape = inner_parsed.shape if isinstance(inner_parsed, Array) else ()

if isinstance(shape, int):
Expand Down Expand Up @@ -754,7 +754,7 @@ class Field:

def __init__(self, name: str, dtype: PolarsDataType):
self.name = name
self.dtype = polars.datatypes.py_type_to_dtype(dtype)
self.dtype = polars.datatypes.parse_into_dtype(dtype)

def __eq__(self, other: Field) -> bool: # type: ignore[override]
return (self.name == other.name) & (self.dtype == other.dtype)
Expand Down
Loading

0 comments on commit fb01aa4

Please sign in to comment.