Skip to content

Commit

Permalink
[PERF] Dyn Compare + Probe Table (#2618)
Browse files Browse the repository at this point in the history
* Enables Dyn Compare + Partial Probe Table creation
`Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering`
  • Loading branch information
samster25 authored Aug 5, 2024
1 parent 3b23e16 commit 230feb3
Show file tree
Hide file tree
Showing 10 changed files with 710 additions and 3 deletions.
196 changes: 196 additions & 0 deletions src/arrow2/src/array/dyn_ord.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use num_traits::Float;
use ord::total_cmp;

use std::cmp::Ordering;

use crate::datatypes::*;
use crate::error::Error;
use crate::offset::Offset;
use crate::{array::*, types::NativeType};

/// Compare the values at two arbitrary indices in two arbitrary arrays.
///
/// The boxed closure is invoked as `cmp(left, right, i, j)` and yields the
/// [`Ordering`] of `left[i]` relative to `right[j]`.
///
/// The comparators constructed in this module downcast both arrays to one
/// concrete array type fixed at build time and `unwrap` the result, so they
/// panic when handed an array of a different concrete type. They also assert
/// that `i` and `j` are in bounds for their respective arrays.
pub type DynArrayComparator =
Box<dyn Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering + Send + Sync>;

/// Reports whether slot `i` of `arr` is non-null.
///
/// An array without a validity bitmap is treated as having no nulls, so every
/// slot is reported as valid. The function is generic over the concrete array
/// type so the `validity()` lookup does not go through a `dyn` call.
///
/// # Safety
///
/// The bit is read with `get_bit_unchecked`, so `i` must be a valid index into
/// the array's validity bitmap. Callers in this file establish that by
/// asserting `i < arr.len()` beforehand.
#[inline]
unsafe fn is_valid<A: Array>(arr: &A, i: usize) -> bool {
    match arr.validity() {
        Some(bitmap) => bitmap.get_bit_unchecked(i),
        // No bitmap at all: nothing in the array is null.
        None => true,
    }
}

/// Orders `left[i]` against `right[j]`, resolving nulls before values.
///
/// * Both slots valid: the result of `cmp()` (which is only called in this case).
/// * Only `left[i]` is null: `Ordering::Greater`, i.e. nulls sort after values.
/// * Only `right[j]` is null: `Ordering::Less`.
/// * Both null: `Ordering::Equal` when `nulls_equal` is set, otherwise
///   `Ordering::Less`, so two nulls never compare as equal.
///
/// # Panics
///
/// Panics if `i` is out of bounds for `left` or `j` is out of bounds for `right`.
#[inline]
fn compare_with_nulls<A: Array, F: FnOnce() -> Ordering>(
    left: &A,
    right: &A,
    i: usize,
    j: usize,
    nulls_equal: bool,
    cmp: F,
) -> Ordering {
    assert!(i < left.len());
    assert!(j < right.len());

    // SAFETY: both indices were bounds-checked by the assertions just above.
    let (left_valid, right_valid) = unsafe { (is_valid(left, i), is_valid(right, j)) };

    if left_valid && right_valid {
        return cmp();
    }
    if left_valid {
        // Right side is null: values come before nulls.
        return Ordering::Less;
    }
    if right_valid {
        // Left side is null: nulls come after values.
        return Ordering::Greater;
    }

    // Both sides are null.
    if nulls_equal {
        Ordering::Equal
    } else {
        Ordering::Less
    }
}

/// Orders two floats, giving NaN a defined position.
///
/// * Neither is NaN: the ordinary `partial_cmp` result.
/// * Only `l` is NaN: `Ordering::Greater`, i.e. NaN sorts after every number.
/// * Only `r` is NaN: `Ordering::Less`.
/// * Both NaN: `Ordering::Equal` when `nans_equal` is set, otherwise
///   `Ordering::Less`.
#[allow(clippy::eq_op)]
#[inline]
fn cmp_float<F: Float>(l: &F, r: &F, nans_equal: bool) -> std::cmp::Ordering {
    if l.is_nan() {
        if !r.is_nan() {
            return Ordering::Greater;
        }
        return if nans_equal {
            Ordering::Equal
        } else {
            Ordering::Less
        };
    }
    if r.is_nan() {
        return Ordering::Less;
    }
    // SAFETY: relies on `partial_cmp` returning `Some` whenever neither operand
    // is NaN, which holds for `f32`/`f64`; both NaN cases were handled above.
    unsafe { l.partial_cmp(r).unwrap_unchecked() }
}

/// Builds a comparator for two `PrimitiveArray<T>` of a floating-point type.
///
/// Nulls are resolved by `compare_with_nulls` using `nulls_equal`; non-null
/// values are ordered by `cmp_float` using `nans_equal`.
///
/// The returned closure panics if either array is not a `PrimitiveArray<T>`.
fn compare_dyn_floats<T: NativeType + Float>(
    nulls_equal: bool,
    nans_equal: bool,
) -> DynArrayComparator {
    Box::new(move |lhs, rhs, i, j| {
        let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        let rhs = rhs.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        compare_with_nulls(lhs, rhs, i, j, nulls_equal, || {
            // SAFETY: `compare_with_nulls` asserts `i < lhs.len()` and
            // `j < rhs.len()` before it invokes this closure.
            let (a, b) = unsafe { (lhs.value_unchecked(i), rhs.value_unchecked(j)) };
            cmp_float::<T>(&a, &b, nans_equal)
        })
    })
}

/// Builds a comparator for two `PrimitiveArray<T>` whose element type is `Ord`.
///
/// Nulls are resolved by `compare_with_nulls` using `nulls_equal`; non-null
/// values are ordered with `total_cmp`.
///
/// The returned closure panics if either array is not a `PrimitiveArray<T>`.
fn compare_dyn_primitives<T: NativeType + Ord>(nulls_equal: bool) -> DynArrayComparator {
    Box::new(move |lhs, rhs, i, j| {
        let lhs = lhs.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        let rhs = rhs.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        compare_with_nulls(lhs, rhs, i, j, nulls_equal, || {
            // SAFETY: `compare_with_nulls` asserts `i < lhs.len()` and
            // `j < rhs.len()` before it invokes this closure.
            let (a, b) = unsafe { (lhs.value_unchecked(i), rhs.value_unchecked(j)) };
            total_cmp(&a, &b)
        })
    })
}

/// Builds a comparator for two `Utf8Array<O>`.
///
/// Nulls are resolved by `compare_with_nulls` using `nulls_equal`; non-null
/// strings are ordered with `str`'s `Ord` implementation.
///
/// The returned closure panics if either array is not a `Utf8Array<O>`.
fn compare_dyn_string<O: Offset>(nulls_equal: bool) -> DynArrayComparator {
    Box::new(move |lhs, rhs, i, j| {
        let lhs = lhs.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
        let rhs = rhs.as_any().downcast_ref::<Utf8Array<O>>().unwrap();
        compare_with_nulls(lhs, rhs, i, j, nulls_equal, || {
            // SAFETY: `compare_with_nulls` asserts `i < lhs.len()` and
            // `j < rhs.len()` before it invokes this closure.
            let (a, b) = unsafe { (lhs.value_unchecked(i), rhs.value_unchecked(j)) };
            a.cmp(b)
        })
    })
}

/// Builds a comparator for two `BinaryArray<O>`.
///
/// Nulls are resolved by `compare_with_nulls` using `nulls_equal`; non-null
/// values are ordered by calling `cmp` on the two values returned by
/// `value_unchecked`.
///
/// The returned closure panics if either array is not a `BinaryArray<O>`.
fn compare_dyn_binary<O: Offset>(nulls_equal: bool) -> DynArrayComparator {
    Box::new(move |lhs, rhs, i, j| {
        let lhs = lhs.as_any().downcast_ref::<BinaryArray<O>>().unwrap();
        let rhs = rhs.as_any().downcast_ref::<BinaryArray<O>>().unwrap();
        compare_with_nulls(lhs, rhs, i, j, nulls_equal, || {
            // SAFETY: `compare_with_nulls` asserts `i < lhs.len()` and
            // `j < rhs.len()` before it invokes this closure.
            let (a, b) = unsafe { (lhs.value_unchecked(i), rhs.value_unchecked(j)) };
            a.cmp(b)
        })
    })
}

/// Builds a comparator for two `BooleanArray`.
///
/// Nulls are resolved by `compare_with_nulls` using `nulls_equal`; non-null
/// values are ordered by calling `cmp` on the two values returned by
/// `value_unchecked`.
///
/// The returned closure panics if either array is not a `BooleanArray`.
fn compare_dyn_boolean(nulls_equal: bool) -> DynArrayComparator {
    Box::new(move |lhs, rhs, i, j| {
        let lhs = lhs.as_any().downcast_ref::<BooleanArray>().unwrap();
        let rhs = rhs.as_any().downcast_ref::<BooleanArray>().unwrap();
        compare_with_nulls(lhs, rhs, i, j, nulls_equal, || {
            // SAFETY: `compare_with_nulls` asserts `i < lhs.len()` and
            // `j < rhs.len()` before it invokes this closure.
            let (a, b) = unsafe { (lhs.value_unchecked(i), rhs.value_unchecked(j)) };
            a.cmp(&b)
        })
    })
}

/// Builds a [`DynArrayComparator`] for a pair of arrays of the given data types.
///
/// The comparator is selected purely from the data types; the arrays
/// themselves are supplied later, on every call of the returned closure.
///
/// `nulls_equal` is forwarded to every comparator and `nans_equal` to the
/// floating-point ones.
///
/// # Errors
///
/// Returns `Error::InvalidArgumentError` when `left` and `right` differ, or
/// when the (shared) data type is not one of the types handled below.
pub fn build_dyn_array_compare(
    left: &DataType,
    right: &DataType,
    nulls_equal: bool,
    nans_equal: bool,
) -> Result<DynArrayComparator> {
    use DataType::*;
    use IntervalUnit::*;
    use TimeUnit::*;

    if left != right {
        return Err(Error::InvalidArgumentError(
            "Can't compare arrays of different types".to_string(),
        ));
    }

    // From here on both types are known to be equal, so dispatching on `left`
    // alone selects the same comparator as dispatching on the pair.
    let comparator = match left {
        Boolean => compare_dyn_boolean(nulls_equal),
        UInt8 => compare_dyn_primitives::<u8>(nulls_equal),
        UInt16 => compare_dyn_primitives::<u16>(nulls_equal),
        UInt32 => compare_dyn_primitives::<u32>(nulls_equal),
        UInt64 => compare_dyn_primitives::<u64>(nulls_equal),
        Int8 => compare_dyn_primitives::<i8>(nulls_equal),
        Int16 => compare_dyn_primitives::<i16>(nulls_equal),
        // Types compared through the `i32` comparator.
        Int32 | Date32 | Time32(Second | Millisecond) | Interval(YearMonth) => {
            compare_dyn_primitives::<i32>(nulls_equal)
        }
        // Types compared through the `i64` comparator. Only timestamps
        // without a timezone are matched here.
        Int64
        | Date64
        | Time64(Microsecond | Nanosecond)
        | Timestamp(Second | Millisecond | Microsecond | Nanosecond, None)
        | Duration(Second | Millisecond | Microsecond | Nanosecond) => {
            compare_dyn_primitives::<i64>(nulls_equal)
        }
        Float32 => compare_dyn_floats::<f32>(nulls_equal, nans_equal),
        Float64 => compare_dyn_floats::<f64>(nulls_equal, nans_equal),
        Decimal(_, _) => compare_dyn_primitives::<i128>(nulls_equal),
        Utf8 => compare_dyn_string::<i32>(nulls_equal),
        LargeUtf8 => compare_dyn_string::<i64>(nulls_equal),
        Binary => compare_dyn_binary::<i32>(nulls_equal),
        LargeBinary => compare_dyn_binary::<i64>(nulls_equal),
        // Dictionary arrays have no dedicated arm and therefore end up in the
        // catch-all below together with every other unhandled type.
        lhs => {
            return Err(Error::InvalidArgumentError(format!(
                "The data type type {lhs:?} has no natural order"
            )))
        }
    };

    Ok(comparator)
}
1 change: 1 addition & 0 deletions src/arrow2/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ mod fmt;
pub mod indexable;
mod iterator;

pub mod dyn_ord;
pub mod growable;
pub mod ord;

Expand Down
8 changes: 8 additions & 0 deletions src/daft-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ impl Schema {
self.fields.keys().cloned().collect()
}

/// Returns the number of entries in this schema's `fields` collection.
pub fn len(&self) -> usize {
self.fields.len()
}

/// Returns `true` when this schema's `fields` collection has no entries.
pub fn is_empty(&self) -> bool {
self.fields.is_empty()
}

pub fn union(&self, other: &Schema) -> DaftResult<Schema> {
let self_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys());
let other_keys: HashSet<&String> = HashSet::from_iter(self.fields.keys());
Expand Down
63 changes: 63 additions & 0 deletions src/daft-core/src/utils/dyn_compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::cmp::Ordering;

use crate::{schema::Schema, DataType};

use arrow2::array::Array;
use common_error::DaftError;
use common_error::DaftResult;

use arrow2::array::dyn_ord::build_dyn_array_compare;
use arrow2::array::dyn_ord::DynArrayComparator;

/// Compares row `i` of one slice of arrays with row `j` of another.
///
/// The boxed closure is invoked as `cmp(left, right, i, j)`, where `left` and
/// `right` each hold one array per column, and yields the resulting
/// [`Ordering`].
pub type MultiDynArrayComparator =
Box<dyn Fn(&[Box<dyn Array>], &[Box<dyn Array>], usize, usize) -> Ordering + Send + Sync>;

/// Builds a single-column comparator for two values of the same data type.
///
/// Both types are converted with `to_physical().to_arrow()` and the result is
/// handed to `build_dyn_array_compare`, together with `nulls_equal` and
/// `nans_equal`.
///
/// # Errors
///
/// Returns `DaftError::TypeError` when `left` and `right` differ. Errors from
/// the arrow conversion or from `build_dyn_array_compare` are propagated.
pub fn build_dyn_compare(
    left: &DataType,
    right: &DataType,
    nulls_equal: bool,
    nans_equal: bool,
) -> DaftResult<DynArrayComparator> {
    if left != right {
        return Err(DaftError::TypeError(format!(
            "Types do not match when creating comparator {} vs {}",
            left, right
        )));
    }

    let left_arrow = left.to_physical().to_arrow()?;
    let right_arrow = right.to_physical().to_arrow()?;
    Ok(build_dyn_array_compare(
        &left_arrow,
        &right_arrow,
        nulls_equal,
        nans_equal,
    )?)
}

/// Builds a multi-column comparator with one sub-comparator per schema field.
///
/// The returned closure walks the sub-comparators and the two array slices in
/// lockstep and returns the first result that is not `Ordering::Equal`; if
/// every compared column is equal the result is `Ordering::Equal`. Because the
/// three sequences are zipped, the walk stops at the shortest of them.
///
/// # Errors
///
/// Propagates the first error returned by `build_dyn_compare` for any field.
pub fn build_dyn_multi_array_compare(
    schema: &Schema,
    nulls_equal: bool,
    nans_equal: bool,
) -> DaftResult<MultiDynArrayComparator> {
    let fn_list = schema
        .fields
        .values()
        .map(|field| build_dyn_compare(&field.dtype, &field.dtype, nulls_equal, nans_equal))
        .collect::<DaftResult<Vec<_>>>()?;

    let combined_fn = Box::new(
        move |left: &[Box<dyn Array>], right: &[Box<dyn Array>], i: usize, j: usize| -> Ordering {
            fn_list
                .iter()
                .zip(left.iter().zip(right.iter()))
                // Lazy: later columns are only compared while earlier ones tie.
                .map(|(f, (l, r))| f(l.as_ref(), r.as_ref(), i, j))
                .find(|ordering| *ordering != Ordering::Equal)
                .unwrap_or(Ordering::Equal)
        },
    );

    Ok(combined_fn)
}
1 change: 1 addition & 0 deletions src/daft-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod arrow;
pub mod display_table;
pub mod dyn_compare;
pub mod hashable_float_wrapper;
pub mod supertype;

Expand Down
80 changes: 80 additions & 0 deletions src/daft-table/src/growable/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
array::growable::{make_growable, Growable},
Series,
};

use crate::Table;

/// A table builder holding one [`Growable`] per column.
pub struct GrowableTable<'a> {
    growables: Vec<Box<dyn Growable + 'a>>,
}

impl<'a> GrowableTable<'a> {
    /// Creates a `GrowableTable` over `tables`, which must all share one schema.
    ///
    /// For every column position the matching series of all tables are
    /// gathered and passed to `make_growable`, along with the field's name and
    /// dtype and the `use_validity` and `capacity` arguments.
    ///
    /// # Errors
    ///
    /// Returns `DaftError::ValueError` when `tables` is empty and
    /// `DaftError::SchemaMismatch` when any table's schema differs from the
    /// first table's schema.
    pub fn new(tables: &[&'a Table], use_validity: bool, capacity: usize) -> DaftResult<Self> {
        let first_table = match tables.first() {
            Some(table) => table,
            None => {
                return Err(DaftError::ValueError(
                    "Need at least 1 Table for GrowableTable".to_string(),
                ))
            }
        };

        let num_tables = tables.len();
        let num_columns = first_table.num_columns();
        let first_schema = first_table.schema.as_ref();

        // One bucket per column, each sized to hold a series from every table.
        let mut series_by_column: Vec<Vec<&Series>> = Vec::with_capacity(num_columns);
        series_by_column.resize_with(num_columns, || Vec::with_capacity(num_tables));

        for tab in tables {
            if tab.schema.as_ref() != first_schema {
                return Err(DaftError::SchemaMismatch(format!(
                    "GrowableTable requires all schemas to match, {} vs {}",
                    first_schema, tab.schema
                )));
            }
            for (col, bucket) in tab.columns.iter().zip(series_by_column.iter_mut()) {
                bucket.push(col);
            }
        }

        let mut growables = Vec::with_capacity(num_columns);
        for (series, field) in series_by_column
            .into_iter()
            .zip(first_schema.fields.values())
        {
            growables.push(make_growable(
                &field.name,
                &field.dtype,
                series,
                use_validity,
                capacity,
            ));
        }
        Ok(Self { growables })
    }

    /// Forwards `extend(index, start, len)` to every column's growable.
    ///
    /// NOTE(review): the previous comment stated that this panics when the
    /// range is out of bounds; the exact bound is enforced by the underlying
    /// growables and is not visible here - confirm against their documentation.
    pub fn extend(&mut self, index: usize, start: usize, len: usize) {
        for growable in self.growables.iter_mut() {
            growable.extend(index, start, len);
        }
    }

    /// Forwards `add_nulls(additional)` to every column's growable.
    pub fn add_nulls(&mut self, additional: usize) {
        for growable in self.growables.iter_mut() {
            growable.add_nulls(additional);
        }
    }

    /// Builds a [`Table`] from the columns produced by the growables.
    ///
    /// With no growables the result is `Table::empty(None)`; otherwise every
    /// growable is built in order and the columns are passed to
    /// `Table::from_nonempty_columns`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while building a column, as well as
    /// any error from the `Table` constructors.
    pub fn build(&mut self) -> DaftResult<Table> {
        if self.growables.is_empty() {
            return Table::empty(None);
        }

        let mut columns = Vec::with_capacity(self.growables.len());
        for growable in self.growables.iter_mut() {
            columns.push(growable.build()?);
        }
        Table::from_nonempty_columns(columns)
    }
}
Loading

0 comments on commit 230feb3

Please sign in to comment.