Skip to content

Commit

Permalink
refactor(rust!): Refactor AnyValue supertype logic (pola-rs#15280)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Mar 25, 2024
1 parent 8f616b8 commit 705b148
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 48 deletions.
26 changes: 4 additions & 22 deletions crates/polars-core/src/frame/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use av_buffer::*;
use rayon::prelude::*;

use crate::prelude::*;
use crate::utils::try_get_supertype;
use crate::utils::{dtypes_to_supertype, try_get_supertype};
use crate::POOL;

#[derive(Debug, Clone, PartialEq, Eq, Default)]
Expand Down Expand Up @@ -83,33 +83,15 @@ pub fn coerce_data_type<A: Borrow<DataType>>(datatypes: &[A]) -> DataType {
try_get_supertype(lhs, rhs).unwrap_or(String)
}

pub fn any_values_to_dtype(column: &[AnyValue]) -> PolarsResult<(DataType, usize)> {
// we need an index-map as the order of dtypes influences how the
// struct fields are constructed.
let mut types_set = PlIndexSet::new();
for val in column.iter() {
types_set.insert(val.into());
}
let n_types = types_set.len();
Ok((types_set_to_dtype(types_set)?, n_types))
}

fn types_set_to_dtype(types_set: PlIndexSet<DataType>) -> PolarsResult<DataType> {
types_set
.into_iter()
.map(Ok)
.reduce(|a, b| try_get_supertype(&a?, &b?))
.unwrap()
}

/// Infer schema from rows and set the supertypes of the columns as column data type.
pub fn rows_to_schema_supertypes(
rows: &[Row],
infer_schema_length: Option<usize>,
) -> PolarsResult<Schema> {
polars_ensure!(!rows.is_empty(), NoData: "no rows, cannot infer schema");

// no of rows to use to infer dtype
let max_infer = infer_schema_length.unwrap_or(rows.len());
polars_ensure!(!rows.is_empty(), NoData: "no rows, cannot infer schema");
let mut dtypes: Vec<PlIndexSet<DataType>> = vec![PlIndexSet::new(); rows[0].0.len()];

for row in rows.iter().take(max_infer) {
Expand All @@ -125,7 +107,7 @@ pub fn rows_to_schema_supertypes(
let dtype = if types_set.is_empty() {
DataType::Unknown
} else {
types_set_to_dtype(types_set)?
dtypes_to_supertype(&types_set)?
};
Ok(Field::new(format!("column_{i}").as_ref(), dtype))
})
Expand Down
22 changes: 2 additions & 20 deletions crates/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Write;
#[cfg(feature = "object")]
use crate::chunked_array::object::registry::ObjectRegistry;
use crate::prelude::*;
use crate::utils::try_get_supertype;
use crate::utils::any_values_to_supertype;

impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom<T, [AnyValue<'a>]> for Series {
fn new(name: &str, v: T) -> Self {
Expand Down Expand Up @@ -47,28 +47,10 @@ impl Series {
},
}
}
/// Compute the supertype of all values, starting from `DataType::Null`.
///
/// Each distinct data type is folded into the running supertype exactly
/// once; repeated data types are skipped. Fails with a `SchemaMismatch`
/// error naming the partial supertype and the offending value when a data
/// type cannot be combined with the supertype found so far.
fn get_any_values_supertype(values: &[AnyValue]) -> PolarsResult<DataType> {
    let mut seen = PlHashSet::<DataType>::new();
    let mut acc = DataType::Null;
    for value in values {
        let dtype = value.dtype();
        // Only the first occurrence of a data type can change the result.
        if seen.contains(&dtype) {
            continue;
        }
        acc = match try_get_supertype(&acc, &dtype) {
            Ok(merged) => merged,
            Err(_) => {
                return Err(polars_err!(
                    SchemaMismatch:
                    "failed to infer supertype of values; partial supertype is {:?}, found value of type {:?}: {}",
                    acc, dtype, value
                ))
            },
        };
        seen.insert(dtype);
    }
    Ok(acc)
}

let dtype = if strict {
get_first_non_null_dtype(values)
} else {
get_any_values_supertype(values)?
any_values_to_supertype(values)?
};
Self::from_any_values_and_dtype(name, values, &dtype, strict)
}
Expand Down
37 changes: 37 additions & 0 deletions crates/polars-core/src/utils/any_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use crate::prelude::*;
use crate::utils::dtypes_to_supertype;

/// Compute the supertype shared by every [`AnyValue`] in `values`.
///
/// The distinct data types of the values are gathered first (in order of
/// first appearance) and then resolved with [`dtypes_to_supertype`]; any
/// error it reports is returned unchanged.
///
/// [`AnyValue`]: crate::datatypes::AnyValue
pub fn any_values_to_supertype<'a, I>(values: I) -> PolarsResult<DataType>
where
    I: IntoIterator<Item = &'a AnyValue<'a>>,
{
    dtypes_to_supertype(&any_values_to_dtype_set(values))
}

/// Compute the supertype of a collection of [`AnyValue`] and count how many
/// distinct data types it contains.
///
/// Returns `(supertype, n_dtypes)`. Any error reported by
/// [`dtypes_to_supertype`] is returned unchanged.
///
/// [`AnyValue`]: crate::datatypes::AnyValue
pub fn any_values_to_supertype_and_n_dtypes<'a, I>(values: I) -> PolarsResult<(DataType, usize)>
where
    I: IntoIterator<Item = &'a AnyValue<'a>>,
{
    let dtype_set = any_values_to_dtype_set(values);
    // The set is only borrowed for the supertype computation, so its length
    // can still be read once that has succeeded.
    dtypes_to_supertype(&dtype_set).map(|supertype| (supertype, dtype_set.len()))
}

/// Extract the ordered set of data types from a collection of AnyValues
///
/// Retaining the order is important if the set is used to determine a supertype,
/// as this can influence how Struct fields are constructed.
fn any_values_to_dtype_set<'a, I>(values: I) -> PlIndexSet<DataType>
where
I: IntoIterator<Item = &'a AnyValue<'a>>,
{
values.into_iter().map(|av| av.into()).collect()
}
2 changes: 2 additions & 0 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod any_value;
pub mod flatten;
pub(crate) mod series;
mod supertype;
use std::borrow::Cow;
use std::ops::{Deref, DerefMut};

pub use any_value::*;
use arrow::bitmap::bitmask::BitMask;
use arrow::bitmap::Bitmap;
pub use arrow::legacy::utils::*;
Expand Down
22 changes: 20 additions & 2 deletions crates/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use super::*;

/// Determine the data type that both `l` and `r` can safely be cast to.
///
/// This is the fallible counterpart of [`get_supertype`]: where that
/// function yields [`None`], this one returns a
/// [`PolarsError::ComputeError`] naming both data types.
pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
    match get_supertype(l, r) {
        Some(supertype) => Ok(supertype),
        None => Err(
            polars_err!(ComputeError: "failed to determine supertype of {} and {}", l, r),
        ),
    }
}

/// Given two datatypes, determine the supertype that both types can safely be cast to
/// Given two data types, determine the data type that both types can safely be cast to.
///
/// Returns [`None`] if no such data type exists.
pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
fn inner(l: &DataType, r: &DataType) -> Option<DataType> {
use DataType::*;
Expand Down Expand Up @@ -278,6 +282,20 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
inner(l, r).or_else(|| inner(r, l))
}

/// Determine the single data type that every data type in `dtypes` can
/// safely be cast to.
///
/// The computation starts from [`DataType::Null`] and combines it with each
/// data type in turn via [`try_get_supertype`], so [`DataType::Null`] is
/// returned if no data types were passed.
///
/// # Errors
///
/// Returns the first error produced by [`try_get_supertype`]; data types
/// after the failing one are not inspected.
pub fn dtypes_to_supertype<'a, I>(dtypes: I) -> PolarsResult<DataType>
where
    I: IntoIterator<Item = &'a DataType>,
{
    let mut supertype = DataType::Null;
    for dtype in dtypes {
        supertype = try_get_supertype(&supertype, dtype)?;
    }
    Ok(supertype)
}

#[cfg(feature = "dtype-struct")]
fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option<DataType> {
let (longest, shortest) = {
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/conversion/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use polars::chunked_array::object::PolarsObjectSafe;
use polars::datatypes::{DataType, Field, OwnedObject, PlHashMap, TimeUnit};
use polars::prelude::{AnyValue, Series};
use polars_core::frame::row::any_values_to_dtype;
use polars_core::utils::any_values_to_supertype_and_n_dtypes;
use pyo3::exceptions::{PyOverflowError, PyTypeError};
use pyo3::intern;
use pyo3::prelude::*;
Expand Down Expand Up @@ -282,11 +282,11 @@ pub(crate) fn py_object_to_any_value(ob: &PyAny, strict: bool) -> PyResult<AnyVa
avs.push(av)
}

let (dtype, n_types) =
any_values_to_dtype(&avs).map_err(|e| PyTypeError::new_err(e.to_string()))?;
let (dtype, n_dtypes) = any_values_to_supertype_and_n_dtypes(&avs)
.map_err(|e| PyTypeError::new_err(e.to_string()))?;

// This path is only taken if there is no question about the data type.
if dtype.is_primitive() && n_types == 1 {
if dtype.is_primitive() && n_dtypes == 1 {
get_list_with_constructor(ob)
} else {
// Push the rest.
Expand Down

0 comments on commit 705b148

Please sign in to comment.