Skip to content

Commit

Permalink
Pass dtype as param
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 10, 2024
1 parent c4e8678 commit 9eeb687
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions py-polars/src/to_numpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ffi::{c_int, c_void};

use ndarray::{Dim, Dimension, IntoDimension};
use numpy::npyffi::{flags, PyArrayObject};
use numpy::{npyffi, Element, IntoPyArray, PyArrayDescrMethods, ToNpyDims, PY_ARRAY_API};
use numpy::{
npyffi, Element, IntoPyArray, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, PY_ARRAY_API,
};
use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;
use polars_core::with_match_physical_numeric_polars_type;
Expand All @@ -12,8 +14,9 @@ use crate::conversion::Wrap;
use crate::dataframe::PyDataFrame;
use crate::series::PySeries;

pub(crate) unsafe fn create_borrowed_np_array<T: NumericNative + Element, I>(
pub(crate) unsafe fn create_borrowed_np_array<I>(
py: Python,
dtype: Bound<PyArrayDescr>,
mut shape: Dim<I>,
flags: c_int,
data: *mut c_void,
Expand All @@ -26,7 +29,7 @@ where
let array = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype_bound(py).into_dtype_ptr(),
dtype.into_dtype_ptr(),
shape.ndim_cint(),
shape.as_dims_ptr(),
// We don't provide strides, but provide flags that tell c/f-order
Expand Down Expand Up @@ -68,10 +71,13 @@ impl PySeries {
let owner = self.clone().into_py(py);
with_match_physical_numeric_polars_type!(self.series.dtype(), |$T| {
let ca: &ChunkedArray<$T> = self.series.unpack::<$T>().unwrap();
let dtype = <$T as PolarsNumericType>::Native::get_dtype_bound(py);
// let dtype = PyArrayDescr::new_bound(py, intern!(py, "datetime64[us]")).unwrap();
let slice = ca.data_views().next().unwrap();
let view = unsafe {
create_borrowed_np_array::<<$T as PolarsNumericType>::Native, _>(
create_borrowed_np_array::<_>(
py,
dtype,
dims,
flags::NPY_ARRAY_FARRAY_RO,
slice.as_ptr() as _,
Expand Down Expand Up @@ -109,12 +115,9 @@ impl PyDataFrame {
// Object to the dataframe keep the memory alive.
let owner = self.clone().into_py(py);

fn get_ptr<T: PolarsNumericType>(
py: Python,
columns: &[Series],
owner: PyObject,
) -> Option<PyObject>
fn get_ptr<T>(py: Python, columns: &[Series], owner: PyObject) -> Option<PyObject>
where
T: PolarsNumericType,
T::Native: Element,
{
let slices = columns
Expand All @@ -139,9 +142,11 @@ impl PyDataFrame {

if all_contiguous {
let start_ptr = first.as_ptr();
let dtype = T::Native::get_dtype_bound(py);
let dims = [first.len(), columns.len()].into_dimension();
Some(create_borrowed_np_array::<T::Native, _>(
Some(create_borrowed_np_array::<_>(
py,
dtype,
dims,
flags::NPY_ARRAY_FARRAY_RO,
start_ptr as _,
Expand Down

0 comments on commit 9eeb687

Please sign in to comment.