-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PERF] Dyn Compare + Probe Table (#2618)
* Enables Dyn Compare + Partial Probe Table creation `Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering`
- Loading branch information
Showing
10 changed files
with
710 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
use num_traits::Float; | ||
use ord::total_cmp; | ||
|
||
use std::cmp::Ordering; | ||
|
||
use crate::datatypes::*; | ||
use crate::error::Error; | ||
use crate::offset::Offset; | ||
use crate::{array::*, types::NativeType}; | ||
|
||
/// Compare the values at two arbitrary indices in two arbitrary arrays. | ||
pub type DynArrayComparator = | ||
Box<dyn Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering + Send + Sync>; | ||
|
||
#[inline] | ||
unsafe fn is_valid<A: Array>(arr: &A, i: usize) -> bool { | ||
// avoid dyn function hop by using generic | ||
arr.validity() | ||
.as_ref() | ||
.map(|x| x.get_bit_unchecked(i)) | ||
.unwrap_or(true) | ||
} | ||
|
||
#[inline] | ||
fn compare_with_nulls<A: Array, F: FnOnce() -> Ordering>( | ||
left: &A, | ||
right: &A, | ||
i: usize, | ||
j: usize, | ||
nulls_equal: bool, | ||
cmp: F, | ||
) -> Ordering { | ||
assert!(i < left.len()); | ||
assert!(j < right.len()); | ||
match (unsafe { is_valid(left, i) }, unsafe { is_valid(right, j) }) { | ||
(true, true) => cmp(), | ||
(false, true) => Ordering::Greater, | ||
(true, false) => Ordering::Less, | ||
(false, false) => { | ||
if nulls_equal { | ||
Ordering::Equal | ||
} else { | ||
Ordering::Less | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[allow(clippy::eq_op)] | ||
#[inline] | ||
fn cmp_float<F: Float>(l: &F, r: &F, nans_equal: bool) -> std::cmp::Ordering { | ||
match (l.is_nan(), r.is_nan()) { | ||
(false, false) => unsafe { l.partial_cmp(r).unwrap_unchecked() }, | ||
(true, true) => { | ||
if nans_equal { | ||
Ordering::Equal | ||
} else { | ||
Ordering::Less | ||
} | ||
} | ||
(true, false) => Ordering::Greater, | ||
(false, true) => Ordering::Less, | ||
} | ||
} | ||
|
||
fn compare_dyn_floats<T: NativeType + Float>( | ||
nulls_equal: bool, | ||
nans_equal: bool, | ||
) -> DynArrayComparator { | ||
Box::new(move |left, right, i, j| { | ||
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); | ||
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); | ||
compare_with_nulls(left, right, i, j, nulls_equal, || { | ||
cmp_float::<T>( | ||
&unsafe { left.value_unchecked(i) }, | ||
&unsafe { right.value_unchecked(j) }, | ||
nans_equal, | ||
) | ||
}) | ||
}) | ||
} | ||
|
||
fn compare_dyn_primitives<T: NativeType + Ord>(nulls_equal: bool) -> DynArrayComparator { | ||
Box::new(move |left, right, i, j| { | ||
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); | ||
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(); | ||
compare_with_nulls(left, right, i, j, nulls_equal, || { | ||
total_cmp(&unsafe { left.value_unchecked(i) }, &unsafe { | ||
right.value_unchecked(j) | ||
}) | ||
}) | ||
}) | ||
} | ||
|
||
fn compare_dyn_string<O: Offset>(nulls_equal: bool) -> DynArrayComparator { | ||
Box::new(move |left, right, i, j| { | ||
let left = left.as_any().downcast_ref::<Utf8Array<O>>().unwrap(); | ||
let right = right.as_any().downcast_ref::<Utf8Array<O>>().unwrap(); | ||
compare_with_nulls(left, right, i, j, nulls_equal, || { | ||
unsafe { left.value_unchecked(i) }.cmp(unsafe { right.value_unchecked(j) }) | ||
}) | ||
}) | ||
} | ||
|
||
fn compare_dyn_binary<O: Offset>(nulls_equal: bool) -> DynArrayComparator { | ||
Box::new(move |left, right, i, j| { | ||
let left = left.as_any().downcast_ref::<BinaryArray<O>>().unwrap(); | ||
let right = right.as_any().downcast_ref::<BinaryArray<O>>().unwrap(); | ||
compare_with_nulls(left, right, i, j, nulls_equal, || { | ||
unsafe { left.value_unchecked(i) }.cmp(unsafe { right.value_unchecked(j) }) | ||
}) | ||
}) | ||
} | ||
|
||
fn compare_dyn_boolean(nulls_equal: bool) -> DynArrayComparator { | ||
Box::new(move |left, right, i, j| { | ||
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap(); | ||
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap(); | ||
compare_with_nulls(left, right, i, j, nulls_equal, || { | ||
unsafe { left.value_unchecked(i) }.cmp(unsafe { &right.value_unchecked(j) }) | ||
}) | ||
}) | ||
} | ||
|
||
pub fn build_dyn_array_compare( | ||
left: &DataType, | ||
right: &DataType, | ||
nulls_equal: bool, | ||
nans_equal: bool, | ||
) -> Result<DynArrayComparator> { | ||
use DataType::*; | ||
use IntervalUnit::*; | ||
use TimeUnit::*; | ||
Ok(match (left, right) { | ||
(a, b) if a != b => { | ||
return Err(Error::InvalidArgumentError( | ||
"Can't compare arrays of different types".to_string(), | ||
)); | ||
} | ||
(Boolean, Boolean) => compare_dyn_boolean(nulls_equal), | ||
(UInt8, UInt8) => compare_dyn_primitives::<u8>(nulls_equal), | ||
(UInt16, UInt16) => compare_dyn_primitives::<u16>(nulls_equal), | ||
(UInt32, UInt32) => compare_dyn_primitives::<u32>(nulls_equal), | ||
(UInt64, UInt64) => compare_dyn_primitives::<u64>(nulls_equal), | ||
(Int8, Int8) => compare_dyn_primitives::<i8>(nulls_equal), | ||
(Int16, Int16) => compare_dyn_primitives::<i16>(nulls_equal), | ||
(Int32, Int32) | ||
| (Date32, Date32) | ||
| (Time32(Second), Time32(Second)) | ||
| (Time32(Millisecond), Time32(Millisecond)) | ||
| (Interval(YearMonth), Interval(YearMonth)) => compare_dyn_primitives::<i32>(nulls_equal), | ||
(Int64, Int64) | ||
| (Date64, Date64) | ||
| (Time64(Microsecond), Time64(Microsecond)) | ||
| (Time64(Nanosecond), Time64(Nanosecond)) | ||
| (Timestamp(Second, None), Timestamp(Second, None)) | ||
| (Timestamp(Millisecond, None), Timestamp(Millisecond, None)) | ||
| (Timestamp(Microsecond, None), Timestamp(Microsecond, None)) | ||
| (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) | ||
| (Duration(Second), Duration(Second)) | ||
| (Duration(Millisecond), Duration(Millisecond)) | ||
| (Duration(Microsecond), Duration(Microsecond)) | ||
| (Duration(Nanosecond), Duration(Nanosecond)) => { | ||
compare_dyn_primitives::<i64>(nulls_equal) | ||
} | ||
(Float32, Float32) => compare_dyn_floats::<f32>(nulls_equal, nans_equal), | ||
(Float64, Float64) => compare_dyn_floats::<f64>(nulls_equal, nans_equal), | ||
(Decimal(_, _), Decimal(_, _)) => compare_dyn_primitives::<i128>(nulls_equal), | ||
(Utf8, Utf8) => compare_dyn_string::<i32>(nulls_equal), | ||
(LargeUtf8, LargeUtf8) => compare_dyn_string::<i64>(nulls_equal), | ||
(Binary, Binary) => compare_dyn_binary::<i32>(nulls_equal), | ||
(LargeBinary, LargeBinary) => compare_dyn_binary::<i64>(nulls_equal), | ||
// (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { | ||
// match (key_type_lhs, key_type_rhs) { | ||
// (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), | ||
// (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), | ||
// (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), | ||
// (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), | ||
// (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), | ||
// (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), | ||
// (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), | ||
// (IntegerType::Int64, IntegerType::Int64) => dyn_dict!(i64, left, right), | ||
// (lhs, _) => { | ||
// return Err(Error::InvalidArgumentError(format!( | ||
// "Dictionaries do not support keys of type {lhs:?}" | ||
// ))) | ||
// } | ||
// } | ||
// } | ||
(lhs, _) => { | ||
return Err(Error::InvalidArgumentError(format!( | ||
"The data type type {lhs:?} has no natural order" | ||
))) | ||
} | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -705,6 +705,7 @@ mod fmt; | |
pub mod indexable; | ||
mod iterator; | ||
|
||
pub mod dyn_ord; | ||
pub mod growable; | ||
pub mod ord; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
use std::cmp::Ordering; | ||
|
||
use crate::{schema::Schema, DataType}; | ||
|
||
use arrow2::array::Array; | ||
use common_error::DaftError; | ||
use common_error::DaftResult; | ||
|
||
use arrow2::array::dyn_ord::build_dyn_array_compare; | ||
use arrow2::array::dyn_ord::DynArrayComparator; | ||
|
||
pub type MultiDynArrayComparator = | ||
Box<dyn Fn(&[Box<dyn Array>], &[Box<dyn Array>], usize, usize) -> Ordering + Send + Sync>; | ||
|
||
pub fn build_dyn_compare( | ||
left: &DataType, | ||
right: &DataType, | ||
nulls_equal: bool, | ||
nans_equal: bool, | ||
) -> DaftResult<DynArrayComparator> { | ||
if left != right { | ||
Err(DaftError::TypeError(format!( | ||
"Types do not match when creating comparator {} vs {}", | ||
left, right | ||
))) | ||
} else { | ||
Ok(build_dyn_array_compare( | ||
&left.to_physical().to_arrow()?, | ||
&right.to_physical().to_arrow()?, | ||
nulls_equal, | ||
nans_equal, | ||
)?) | ||
} | ||
} | ||
|
||
pub fn build_dyn_multi_array_compare( | ||
schema: &Schema, | ||
nulls_equal: bool, | ||
nans_equal: bool, | ||
) -> DaftResult<MultiDynArrayComparator> { | ||
let mut fn_list = Vec::with_capacity(schema.len()); | ||
for field in schema.fields.values() { | ||
fn_list.push(build_dyn_compare( | ||
&field.dtype, | ||
&field.dtype, | ||
nulls_equal, | ||
nans_equal, | ||
)?); | ||
} | ||
let combined_fn = Box::new( | ||
move |left: &[Box<dyn Array>], right: &[Box<dyn Array>], i: usize, j: usize| -> Ordering { | ||
for (f, (l, r)) in fn_list.iter().zip(left.iter().zip(right.iter())) { | ||
match f(l.as_ref(), r.as_ref(), i, j) { | ||
std::cmp::Ordering::Equal => continue, | ||
other => return other, | ||
} | ||
} | ||
Ordering::Equal | ||
}, | ||
); | ||
|
||
Ok(combined_fn) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod arrow; | ||
pub mod display_table; | ||
pub mod dyn_compare; | ||
pub mod hashable_float_wrapper; | ||
pub mod supertype; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
use common_error::{DaftError, DaftResult}; | ||
use daft_core::{ | ||
array::growable::{make_growable, Growable}, | ||
Series, | ||
}; | ||
|
||
use crate::Table; | ||
|
||
pub struct GrowableTable<'a> { | ||
growables: Vec<Box<dyn Growable + 'a>>, | ||
} | ||
|
||
impl<'a> GrowableTable<'a> { | ||
pub fn new(tables: &[&'a Table], use_validity: bool, capacity: usize) -> DaftResult<Self> { | ||
let num_tables = tables.len(); | ||
if tables.is_empty() { | ||
return Err(DaftError::ValueError( | ||
"Need at least 1 Table for GrowableTable".to_string(), | ||
)); | ||
} | ||
|
||
let first_table = tables.first().unwrap(); | ||
let num_columns = first_table.num_columns(); | ||
let first_schema = first_table.schema.as_ref(); | ||
|
||
let mut series_list = (0..num_columns) | ||
.map(|_| Vec::<&Series>::with_capacity(num_tables)) | ||
.collect::<Vec<_>>(); | ||
|
||
for tab in tables { | ||
if tab.schema.as_ref() != first_schema { | ||
return Err(DaftError::SchemaMismatch(format!( | ||
"GrowableTable requires all schemas to match, {} vs {}", | ||
first_schema, tab.schema | ||
))); | ||
} | ||
for (col, v) in tab.columns.iter().zip(series_list.iter_mut()) { | ||
v.push(col); | ||
} | ||
} | ||
let growables = series_list | ||
.into_iter() | ||
.zip(first_schema.fields.values()) | ||
.map(|(vector, f)| make_growable(&f.name, &f.dtype, vector, use_validity, capacity)) | ||
.collect::<Vec<_>>(); | ||
Ok(Self { growables }) | ||
} | ||
|
||
/// This function panics if the range is out of bounds, i.e. if `start + len >= array.len()`. | ||
pub fn extend(&mut self, index: usize, start: usize, len: usize) { | ||
if !self.growables.is_empty() { | ||
self.growables | ||
.iter_mut() | ||
.for_each(|g| g.extend(index, start, len)) | ||
} | ||
} | ||
|
||
/// Extends this [`Growable`] with null elements | ||
pub fn add_nulls(&mut self, additional: usize) { | ||
if !self.growables.is_empty() { | ||
self.growables | ||
.iter_mut() | ||
.for_each(|g| g.add_nulls(additional)) | ||
} | ||
} | ||
|
||
/// Builds an array from the [`Growable`] | ||
pub fn build(&mut self) -> DaftResult<Table> { | ||
if self.growables.is_empty() { | ||
Table::empty(None) | ||
} else { | ||
let columns = self | ||
.growables | ||
.iter_mut() | ||
.map(|g| g.build()) | ||
.collect::<DaftResult<Vec<_>>>()?; | ||
Table::from_nonempty_columns(columns) | ||
} | ||
} | ||
} |
Oops, something went wrong.