Skip to content

Commit

Permalink
refactor(python): Return correct temporal type from Rust in to_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 10, 2024
1 parent e94daa6 commit 1515087
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 25 deletions.
17 changes: 1 addition & 16 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4377,17 +4377,6 @@ def raise_on_copy() -> None:
msg = "cannot return a zero-copy array"
raise ValueError(msg)

def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any:
if dtype == Date:
return np.dtype("datetime64[D]")
elif dtype == Duration:
return np.dtype(f"timedelta64[{dtype.time_unit}]") # type: ignore[union-attr]
elif dtype == Datetime:
return np.dtype(f"datetime64[{dtype.time_unit}]") # type: ignore[union-attr]
else:
msg = f"invalid temporal type: {dtype}"
raise TypeError(msg)

if self.n_chunks() > 1:
raise_on_copy()
self = self.rechunk()
Expand Down Expand Up @@ -4421,19 +4410,15 @@ def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any:
np_array = s_u8._s.to_numpy_view().view(bool)
elif dtype == Date:
raise_on_copy()
np_dtype = temporal_dtype_to_numpy(dtype)
s_i32 = self.to_physical()
np_array = s_i32._s.to_numpy_view().astype(np_dtype)
np_array = s_i32._s.to_numpy_view().astype("<M8[D]")
else:
raise_on_copy()
np_array = self._s.to_numpy()

else:
raise_on_copy()
np_array = self._s.to_numpy()
if dtype in (Datetime, Duration, Date):
np_dtype = temporal_dtype_to_numpy(dtype)
np_array = np_array.view(np_dtype)

if writable and not np_array.flags.writeable:
raise_on_copy()
Expand Down
53 changes: 44 additions & 9 deletions py-polars/src/series/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,34 @@ impl PySeries {
np_arr.into_py(py)
},
Date => date_series_to_numpy(py, s),
Datetime(_, _) | Duration(_) => temporal_series_to_numpy(py, s),
Datetime(tu, _) => {
use numpy::datetime::{units, Datetime};
match tu {
TimeUnit::Milliseconds => {
temporal_series_to_numpy::<Datetime<units::Milliseconds>>(py, s)
},
TimeUnit::Microseconds => {
temporal_series_to_numpy::<Datetime<units::Microseconds>>(py, s)
},
TimeUnit::Nanoseconds => {
temporal_series_to_numpy::<Datetime<units::Nanoseconds>>(py, s)
},
}
},
Duration(tu) => {
use numpy::datetime::{units, Timedelta};
match tu {
TimeUnit::Milliseconds => {
temporal_series_to_numpy::<Timedelta<units::Milliseconds>>(py, s)
},
TimeUnit::Microseconds => {
temporal_series_to_numpy::<Timedelta<units::Microseconds>>(py, s)
},
TimeUnit::Nanoseconds => {
temporal_series_to_numpy::<Timedelta<units::Nanoseconds>>(py, s)
},
}
},
Time => {
let ca = s.time().unwrap();
let iter = time_to_pyobject_iter(py, ca);
Expand Down Expand Up @@ -254,19 +281,27 @@ where
}
/// Convert dates directly to i64 with i64::MIN representing a null value
fn date_series_to_numpy(py: Python, s: &Series) -> PyObject {
use numpy::datetime::{units, Datetime};

let s_phys = s.to_physical_repr();
let ca = s_phys.i32().unwrap();
let mapper = |opt_v: Option<i32>| match opt_v {
Some(v) => v as i64,
None => i64::MIN,
let mapper = |opt_v: Option<i32>| {
let int = match opt_v {
Some(v) => v as i64,
None => i64::MIN,
};
int.into()
};
let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(mapper));
np_arr.into_py(py)
let iter = ca.iter().map(mapper);
PyArray1::<Datetime<units::Days>>::from_iter_bound(py, iter).into_py(py)
}
/// Convert datetimes and durations with i64::MIN representing a null value
fn temporal_series_to_numpy(py: Python, s: &Series) -> PyObject {
fn temporal_series_to_numpy<T>(py: Python, s: &Series) -> PyObject
where
T: From<i64> + numpy::Element,
{
let s_phys = s.to_physical_repr();
let ca = s_phys.i64().unwrap();
let np_arr = PyArray1::from_iter_bound(py, ca.iter().map(|v| v.unwrap_or(i64::MIN)));
np_arr.into_py(py)
let iter = ca.iter().map(|v| v.unwrap_or(i64::MIN).into());
PyArray1::<T>::from_iter_bound(py, iter).into_py(py)
}
1 change: 1 addition & 0 deletions py-polars/tests/unit/interop/numpy/test_to_numpy_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_series_to_numpy_date() -> None:

assert s.to_list() == result.tolist()
assert result.dtype == np.dtype("datetime64[D]")
assert result.flags.writeable is True
assert_allow_copy_false_raises(s)


Expand Down

0 comments on commit 1515087

Please sign in to comment.