Skip to content

Commit

Permalink
feat: NDarray/Tensor support (#16466)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored May 25, 2024
1 parent bebdc36 commit 019dfe8
Show file tree
Hide file tree
Showing 29 changed files with 367 additions and 96 deletions.
43 changes: 42 additions & 1 deletion crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::collections::BTreeMap;

#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;

use super::*;
#[cfg(feature = "object")]
use crate::chunked_array::object::registry::ObjectRegistry;
Expand Down Expand Up @@ -174,6 +177,25 @@ impl DataType {
}
}

#[cfg(feature = "dtype-array")]
/// Get the full shape of a multidimensional array.
pub fn get_shape(&self) -> Option<Vec<usize>> {
fn get_shape_impl(dt: &DataType, shape: &mut Vec<usize>) {
if let DataType::Array(inner, size) = dt {
shape.push(*size);
get_shape_impl(inner, shape);
}
}

if let DataType::Array(inner, size) = self {
let mut shape = vec![*size];
get_shape_impl(inner, &mut shape);
Some(shape)
} else {
None
}
}

/// Get the inner data type of a nested type.
pub fn inner_dtype(&self) -> Option<&DataType> {
match self {
Expand All @@ -184,6 +206,15 @@ impl DataType {
}
}

/// Get the absolute inner data type of a nested type.
pub fn leaf_dtype(&self) -> &DataType {
let mut prev = self;
while let Some(dtype) = prev.inner_dtype() {
prev = dtype
}
prev
}

/// Convert to the physical data type
#[must_use]
pub fn to_physical(&self) -> DataType {
Expand Down Expand Up @@ -646,7 +677,17 @@ impl Display for DataType {
DataType::Duration(tu) => return write!(f, "duration[{tu}]"),
DataType::Time => "time",
#[cfg(feature = "dtype-array")]
DataType::Array(tp, size) => return write!(f, "array[{tp}, {size}]"),
DataType::Array(_, _) => {
let tp = self.leaf_dtype();

let dims = self.get_shape().unwrap();
let shape = if dims.len() == 1 {
format!("{}", dims[0])
} else {
format_tuple!(dims)
};
return write!(f, "array[{tp}, {}]", shape);
},
DataType::List(tp) => return write!(f, "list[{tp}]"),
#[cfg(feature = "object")]
DataType::Object(s, _) => s,
Expand Down
1 change: 0 additions & 1 deletion crates/polars-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mod downcast;
mod extend;
mod null;
mod to_list;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-expr/src/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl PhysicalExpr for AggregationExpr {
let s = match ac.agg_state() {
// mean agg:
// -> f64 -> list<f64>
AggState::AggregatedScalar(s) => s.reshape(&[-1, 1]).unwrap(),
AggState::AggregatedScalar(s) => s.reshape_list(&[-1, 1]).unwrap(),
_ => {
let agg = ac.aggregated();
agg.as_list().into_series()
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub(crate) use filter::*;
pub(crate) use literal::*;
use polars_core::prelude::*;
use polars_io::predicates::PhysicalIoExpr;
use polars_ops::prelude::*;
use polars_plan::prelude::*;
#[cfg(feature = "dynamic_group_by")]
pub(crate) use rolling::RollingExpr;
Expand Down Expand Up @@ -419,7 +420,7 @@ impl<'a> AggregationContext<'a> {
self.groups();
let rows = self.groups.len();
let s = s.new_from_index(0, rows);
s.reshape(&[rows as i64, -1]).unwrap()
s.reshape_list(&[rows as i64, -1]).unwrap()
},
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::_split_offsets;
use polars_core::POOL;
use polars_ops::prelude::*;
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::*;
use rayon::prelude::*;
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::chunked_array::list::sum_mean::sum_with_nulls;
#[cfg(feature = "diff")]
use crate::prelude::diff;
use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical};
use crate::series::ArgAgg;
use crate::series::{ArgAgg, SeriesReshape};

pub(super) fn has_inner_nulls(ca: &ListChunked) -> bool {
for arr in ca.downcast_iter() {
Expand All @@ -44,7 +44,7 @@ fn cast_rhs(
}
if !matches!(s.dtype(), DataType::List(_)) && s.dtype() == inner_type {
// coerce to list JIT
*s = s.reshape(&[-1, 1]).unwrap();
*s = s.reshape_list(&[-1, 1]).unwrap();
}
if s.dtype() != dtype {
*s = s.cast(dtype).map_err(|e| {
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod rank;
mod reinterpret;
#[cfg(feature = "replace")]
mod replace;
mod reshape;
#[cfg(feature = "rle")]
mod rle;
#[cfg(feature = "rolling_window")]
Expand Down Expand Up @@ -137,6 +138,7 @@ pub use unique::*;
pub use various::*;
mod not;
pub use not::*;
pub use reshape::*;

pub trait SeriesSealed {
fn as_series(&self) -> &Series;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
use std::borrow::Cow;
#[cfg(feature = "dtype-array")]
use std::collections::VecDeque;

use arrow::array::*;
use arrow::legacy::kernels::list::array_to_unit_list;
use arrow::offset::Offsets;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::datatypes::{DataType, ListChunked};
use polars_core::prelude::{IntoSeries, Series};
use polars_error::{polars_bail, polars_ensure, PolarsResult};
#[cfg(feature = "dtype-array")]
use polars_utils::format_tuple;

use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;

fn reshape_fast_path(name: &str, s: &Series) -> Series {
Expand All @@ -22,15 +30,45 @@ fn reshape_fast_path(name: &str, s: &Series) -> Series {
ca.into_series()
}

impl Series {
pub trait SeriesReshape: SeriesSealed {
/// Recurse nested types until we are at the leaf array.
fn get_leaf_array(&self) -> Series {
let s = self.as_series();
match s.dtype() {
#[cfg(feature = "dtype-array")]
DataType::Array(dtype, _) => {
let ca = s.array().unwrap();
let chunks = ca
.downcast_iter()
.map(|arr| arr.values().clone())
.collect::<Vec<_>>();
// Safety: guarded by the type system
unsafe { Series::from_chunks_and_dtype_unchecked(s.name(), chunks, dtype) }
.get_leaf_array()
},
DataType::List(dtype) => {
let ca = s.list().unwrap();
let chunks = ca
.downcast_iter()
.map(|arr| arr.values().clone())
.collect::<Vec<_>>();
// Safety: guarded by the type system
unsafe { Series::from_chunks_and_dtype_unchecked(s.name(), chunks, dtype) }
.get_leaf_array()
},
_ => s.clone(),
}
}

/// Convert the values of this Series to a ListChunked with a length of 1,
/// so a Series of `[1, 2, 3]` becomes `[[1, 2, 3]]`.
pub fn implode(&self) -> PolarsResult<ListChunked> {
let s = self.rechunk();
fn implode(&self) -> PolarsResult<ListChunked> {
let s = self.as_series();
let s = s.rechunk();
let values = s.array_ref(0);

let offsets = vec![0i64, values.len() as i64];
let inner_type = self.dtype();
let inner_type = s.dtype();

let data_type = ListArray::<i64>::default_datatype(values.data_type().clone());

Expand All @@ -44,20 +82,70 @@ impl Series {
)
};

let mut ca = ListChunked::with_chunk(self.name(), arr);
let mut ca = ListChunked::with_chunk(s.name(), arr);
unsafe { ca.to_logical(inner_type.clone()) };
ca.set_fast_explode();
Ok(ca)
}

pub fn reshape(&self, dimensions: &[i64]) -> PolarsResult<Series> {
#[cfg(feature = "dtype-array")]
fn reshape_array(&self, dimensions: &[i64]) -> PolarsResult<Series> {
let mut dims = dimensions.iter().copied().collect::<VecDeque<_>>();

let leaf_array = self.get_leaf_array();
let size = leaf_array.len() as i64;

// Infer dimension
if dims.contains(&-1) {
let infer_dims = dims.iter().filter(|d| **d == -1).count();
polars_ensure!(infer_dims == 1, InvalidOperation: "can only infer single dimension, found {}", infer_dims);

let mut prod = 1;
for &dim in &dims {
if dim != -1 {
prod *= dim;
}
}
polars_ensure!(size % prod == 0, InvalidOperation: "cannot reshape array of size {} into shape: {}", size, format_tuple!(dims));
let inferred_value = size / prod;
for dim in &mut dims {
if *dim == -1 {
*dim = inferred_value;
break;
}
}
}
let leaf_array = leaf_array.rechunk();
let mut prev_dtype = leaf_array.dtype().clone();
let mut prev_array = leaf_array.chunks()[0].clone();

// We pop the outer dimension as that is the height of the series.
let _ = dims.pop_front();
while let Some(dim) = dims.pop_back() {
prev_dtype = DataType::Array(Box::new(prev_dtype), dim as usize);

prev_array =
FixedSizeListArray::new(prev_dtype.to_arrow(true), prev_array, None).boxed();
}
Ok(unsafe {
Series::from_chunks_and_dtype_unchecked(
leaf_array.name(),
vec![prev_array],
&prev_dtype,
)
})
}

fn reshape_list(&self, dimensions: &[i64]) -> PolarsResult<Series> {
let s = self.as_series();

if dimensions.is_empty() {
polars_bail!(ComputeError: "reshape `dimensions` cannot be empty")
}
let s = if let DataType::List(_) = self.dtype() {
Cow::Owned(self.explode()?)
let s = if let DataType::List(_) = s.dtype() {
Cow::Owned(s.explode()?)
} else {
Cow::Borrowed(self)
Cow::Borrowed(s)
};

let s_ref = s.as_ref();
Expand All @@ -78,7 +166,7 @@ impl Series {

if s_ref.len() == 0_usize {
if (rows == -1 || rows == 0) && (cols == -1 || cols == 0) {
let s = reshape_fast_path(self.name(), s_ref);
let s = reshape_fast_path(s.name(), s_ref);
return Ok(s);
} else {
polars_bail!(ComputeError: "cannot reshape len 0 into shape {:?}", dimensions,)
Expand All @@ -97,7 +185,7 @@ impl Series {

// Fast path, we can create a unit list so we only allocate offsets.
if rows as usize == s_ref.len() && cols == 1 {
let s = reshape_fast_path(self.name(), s_ref);
let s = reshape_fast_path(s.name(), s_ref);
return Ok(s);
}

Expand All @@ -107,7 +195,7 @@ impl Series {
);

let mut builder =
get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, self.name())?;
get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name())?;

let mut offset = 0i64;
for _ in 0..rows {
Expand All @@ -118,14 +206,18 @@ impl Series {
Ok(builder.finish().into_series())
},
_ => {
panic!("more than two dimensions not yet supported");
polars_bail!(InvalidOperation: "more than two dimensions not supported in reshaping to List.\n\nConsider reshaping to Array type.");
},
}
}
}

impl SeriesReshape for Series {}

#[cfg(test)]
mod test {
use polars_core::prelude::*;

use super::*;

#[test]
Expand Down Expand Up @@ -153,7 +245,7 @@ mod test {
(&[-1, 2], 2),
(&[2, -1], 2),
] {
let out = s.reshape(dims)?;
let out = s.reshape_list(dims)?;
assert_eq!(out.len(), list_len);
assert!(matches!(out.dtype(), DataType::List(_)));
assert_eq!(out.explode()?.len(), 4);
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ pub(super) fn unique_counts(s: &Series) -> PolarsResult<Series> {
polars_ops::prelude::unique_counts(s)
}

pub(super) fn reshape(s: &Series, dimensions: Vec<i64>) -> PolarsResult<Series> {
s.reshape(&dimensions)
pub(super) fn reshape(s: &Series, dimensions: &[i64], nested: &NestedType) -> PolarsResult<Series> {
match nested {
NestedType::List => s.reshape_list(dimensions),
#[cfg(feature = "dtype-array")]
NestedType::Array => s.reshape_array(dimensions),
}
}

#[cfg(feature = "repeat_by")]
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ pub(super) fn concat(s: &mut [Series]) -> PolarsResult<Option<Series>> {
let mut first_ca = match first.list().ok() {
Some(ca) => ca,
None => {
first = first.reshape(&[-1, 1]).unwrap();
first = first.reshape_list(&[-1, 1]).unwrap();
first.list().unwrap()
},
}
Expand Down Expand Up @@ -482,7 +482,7 @@ pub(super) fn gather(args: &[Series], null_on_oob: bool) -> PolarsResult<Series>
let idx = idx.get(0)?.try_extract::<i64>()?;
let out = ca.lst_get(idx, null_on_oob)?;
// make sure we return a list
out.reshape(&[-1, 1])
out.reshape_list(&[-1, 1])
} else {
ca.lst_gather(idx, null_on_oob)
}
Expand Down
Loading

0 comments on commit 019dfe8

Please sign in to comment.