From 73bdd36e5b31141d4706d01f441999c9d9d0e76a Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 10 May 2024 17:13:59 +0200 Subject: [PATCH] refactor(python): Return correct temporal type from Rust in `to_numpy` (#14353) --- py-polars/polars/series/series.py | 17 +----- py-polars/src/series/export.rs | 53 +++++++++++++++---- .../interop/numpy/test_to_numpy_series.py | 1 + 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 17b7ad7fd3279..ebce4249dbeb4 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4377,17 +4377,6 @@ def raise_on_copy() -> None: msg = "cannot return a zero-copy array" raise ValueError(msg) - def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: - if dtype == Date: - return np.dtype("datetime64[D]") - elif dtype == Duration: - return np.dtype(f"timedelta64[{dtype.time_unit}]") # type: ignore[union-attr] - elif dtype == Datetime: - return np.dtype(f"datetime64[{dtype.time_unit}]") # type: ignore[union-attr] - else: - msg = f"invalid temporal type: {dtype}" - raise TypeError(msg) - if self.n_chunks() > 1: raise_on_copy() self = self.rechunk() @@ -4421,9 +4410,8 @@ def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: np_array = s_u8._s.to_numpy_view().view(bool) elif dtype == Date: raise_on_copy() - np_dtype = temporal_dtype_to_numpy(dtype) s_i32 = self.to_physical() - np_array = s_i32._s.to_numpy_view().astype(np_dtype) + np_array = s_i32._s.to_numpy_view().astype(" Any: else: raise_on_copy() np_array = self._s.to_numpy() - if dtype in (Datetime, Duration, Date): - np_dtype = temporal_dtype_to_numpy(dtype) - np_array = np_array.view(np_dtype) if writable and not np_array.flags.writeable: raise_on_copy() diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index 38e4fae3ecb29..1097e823f0f0d 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -185,7 +185,34 @@ impl PySeries { np_arr.into_py(py) }, Date => date_series_to_numpy(py, s), - Datetime(_, _) | Duration(_) => temporal_series_to_numpy(py, s), + Datetime(tu, _) => { + use numpy::datetime::{units, Datetime}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, + Duration(tu) => { + use numpy::datetime::{units, Timedelta}; + match tu { + TimeUnit::Milliseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Microseconds => { + temporal_series_to_numpy::>(py, s) + }, + TimeUnit::Nanoseconds => { + temporal_series_to_numpy::>(py, s) + }, + } + }, Time => { let ca = s.time().unwrap(); let iter = time_to_pyobject_iter(py, ca); @@ -254,19 +281,27 @@ where } /// Convert dates directly to i64 with i64::MIN representing a null value fn date_series_to_numpy(py: Python, s: &Series) -> PyObject { + use numpy::datetime::{units, Datetime}; + let s_phys = s.to_physical_repr(); let ca = s_phys.i32().unwrap(); - let mapper = |opt_v: Option| match opt_v { - Some(v) => v as i64, - None => i64::MIN, + let mapper = |opt_v: Option| { + let int = match opt_v { + Some(v) => v as i64, + None => i64::MIN, + }; + int.into() }; - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(mapper)); - np_arr.into_py(py) + let iter = ca.iter().map(mapper); + PyArray1::>::from_iter_bound(py, iter).into_py(py) } /// Convert datetimes and durations with i64::MIN representing a null value -fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject { +fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject +where + T: From + numpy::Element, +{ let s_phys = s.to_physical_repr(); let ca = s_phys.i64().unwrap(); - let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|v| v.unwrap_or(i64::MIN))); - np_arr.into_py(py) + let iter = ca.iter().map(|v| v.unwrap_or(i64::MIN).into()); + PyArray1::::from_iter_bound(py, iter).into_py(py) } diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py index fe2909672fa8f..5674d183e3f66 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -130,6 +130,7 @@ def test_series_to_numpy_date() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.dtype("datetime64[D]") + assert result.flags.writeable is True assert_allow_copy_false_raises(s)