Skip to content

Commit

Permalink
fix(python): Fix dtype parameter in pandas_to_pyseries function
Browse files Browse the repository at this point in the history
  • Loading branch information
luke396 committed Apr 29, 2024
1 parent 2e28176 commit b12edeb
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 2 additions & 1 deletion py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,15 @@ def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> Series:
def pandas_to_pyseries(
name: str,
values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex,
dtype: PolarsDataType | None = None,
*,
nan_to_null: bool = True,
) -> PySeries:
"""Construct a PySeries from a pandas Series or DatetimeIndex."""
if not name and values.name is not None:
name = str(values.name)
if is_simple_numpy_backed_pandas_series(values):
return pl.Series(name, values.to_numpy(), nan_to_null=nan_to_null)._s
return pl.Series(name, values.to_numpy(), dtype=dtype, nan_to_null=nan_to_null)._s
if not _PYARROW_AVAILABLE:
msg = (
"pyarrow is required for converting a pandas series to Polars, "
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(
elif _check_for_pandas(values) and isinstance(
values, (pd.Series, pd.Index, pd.DatetimeIndex)
):
self._s = pandas_to_pyseries(name, values)
self._s = pandas_to_pyseries(name, values, dtype=dtype)

elif _is_generator(values):
self._s = iterable_to_pyseries(name, values, dtype=dtype, strict=strict)
Expand Down
4 changes: 4 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,3 +2337,7 @@ def test_search_sorted(

multiple_s = s.search_sorted(multiple)
assert_series_equal(multiple_s, pl.Series(multiple_expected, dtype=pl.UInt32))

def test_series_from_pandas_with_dtype()->None:
s = pl.Series('foo', pd.Series([1,2,3]), pl.Float32)
assert_series_equal(s, pl.Series('foo', [1,2,3], dtype=pl.Float32))

0 comments on commit b12edeb

Please sign in to comment.