Skip to content

Commit

Permalink
Make 25% faster
Browse files Browse the repository at this point in the history
  • Loading branch information
AlSchlo committed Apr 29, 2024
1 parent 0d060b1 commit 358ccc4
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 177 deletions.
92 changes: 47 additions & 45 deletions optd-datafusion-repr/src/cost/base_cost/stats.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::{collections::HashMap, sync::Arc};
use std::{
collections::HashMap,
sync::{mpsc::Receiver, Arc},
thread::JoinHandle,
};

use arrow_schema::{ArrowError, DataType, SchemaRef};
use arrow_schema::{ArrowError, DataType, Schema, SchemaRef};
use datafusion::arrow::array::{
Array, BooleanArray, Date32Array, Float32Array, Int16Array, Int32Array, Int8Array, RecordBatch,
RecordBatchIterator, RecordBatchReader, StringArray, UInt16Array, UInt32Array, UInt8Array,
StringArray, UInt16Array, UInt32Array, UInt8Array,
};
use itertools::Itertools;
use optd_core::rel_node::{SerializableOrderedF64, Value};
Expand Down Expand Up @@ -330,16 +334,17 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
.zip(hlls)
.zip(null_counts)
.for_each(|(((column_comb, mg), hll), count)| {
let filtered_nulls: Vec<ColumnCombValue> = column_comb
.iter()
.filter(|row| row.iter().any(|val| val.is_some()))
.cloned()
.collect();
let nb_rows = column_comb.len() as i32;
let filtered_nulls = column_comb
.into_iter()
.filter(|row| row.iter().any(|val| val.is_some()));

*count += column_comb.len() as i32;

*count += nb_rows - filtered_nulls.len() as i32;
mg.aggregate(&filtered_nulls);
hll.aggregate(&filtered_nulls);
filtered_nulls.for_each(|e| {
mg.insert_element(e, 1);
hll.process(e);
*count -= 1;
});
});
}

Expand Down Expand Up @@ -373,46 +378,38 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
});
}

pub fn from_record_batches<I: IntoIterator<Item = Result<RecordBatch, ArrowError>>>(
batch_iter_builder: impl Fn() -> anyhow::Result<RecordBatchIterator<I>>,
pub fn from_record_batches(
first_batch_channel: impl FnOnce()
-> (JoinHandle<()>, Receiver<Result<RecordBatch, ArrowError>>),
second_batch_channel: impl FnOnce()
-> (JoinHandle<()>, Receiver<Result<RecordBatch, ArrowError>>),
combinations: Vec<ColumnsIdx>,
schema: Arc<Schema>,
) -> anyhow::Result<Self> {
let batch_iter = batch_iter_builder()?;
let comb_stat_types = Self::get_stats_types(&combinations, &batch_iter.schema());
let comb_stat_types = Self::get_stats_types(&combinations, &schema);
let nb_stats = comb_stat_types.len();

// 0. Just count row numbers if no combinations can give stats.
if nb_stats == 0 {
let mut row_cnt = 0;
for batch in batch_iter {
row_cnt += batch?.num_rows();
}

return Ok(Self {
row_cnt,
column_comb_stats: HashMap::new(),
});
}

// TODO(Alexis): This materialization is OK as JOB only takes 1GB, but should be made in parallel...
// Unfortunately, par_bridge doesn't work as the BatchIterator doesn't implement Send.
let materialized: Vec<_> = batch_iter.collect();

// 1. FIRST PASS: hlls + mgs + null_cnts.
// 1. FIRST PASS: hlls + mgs + null_cnts.
let now = std::time::Instant::now();
let (hlls, mgs, null_cnts) = materialized
.par_iter()
let (handle, receiver) = first_batch_channel();

let (hlls, mgs, null_cnts) = receiver
.into_iter()
.par_bridge()
.fold(Self::first_pass_stats_id(nb_stats), |local_stats, batch| {
let mut local_stats = local_stats?;

match batch {
Ok(batch) => {
let (hlls, mgs, null_cnts) = &mut local_stats;
let comb = Self::get_column_combs(batch, &comb_stat_types);
let comb = Self::get_column_combs(&batch, &comb_stat_types);
Self::generate_partial_stats(&comb, mgs, hlls, null_cnts);
Ok(local_stats)
}
Err(_) => todo!(), // TODO(Alexis): Could not satisfy the type checker otherwise, but never happens!
Err(e) => {
println!("Err: {:?},, {:?}", e, comb_stat_types.len());
Err(e.into())
}
}
})
.reduce(
Expand All @@ -433,12 +430,17 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
Ok(final_stats)
},
)?;

let _ = handle.join();
let first = now.elapsed();

// 2. SECOND PASS: mcv + tdigest + row_cnts.
let now = std::time::Instant::now();
let (distrs, cnts, row_cnts) = materialized
.par_iter()
let (handle, receiver) = second_batch_channel();

let (distrs, cnts, row_cnts) = receiver
.into_iter()
.par_bridge()
.fold(
Self::second_pass_stats_id(&comb_stat_types, &mgs, nb_stats),
|local_stats, batch| {
Expand All @@ -447,11 +449,11 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
match batch {
Ok(batch) => {
let (distrs, cnts, row_cnts) = &mut local_stats;
let comb = Self::get_column_combs(batch, &comb_stat_types);
let comb = Self::get_column_combs(&batch, &comb_stat_types);
Self::generate_full_stats(&comb, cnts, distrs, row_cnts);
Ok(local_stats)
}
Err(_) => todo!(), // TODO(Alexis): Could not satisfy the type checker otherwise, but never happens!
Err(e) => Err(e.into()),
}
},
)
Expand Down Expand Up @@ -479,10 +481,11 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
Ok(final_stats)
},
)?;
let second = now.elapsed();

let _ = handle.join();
println!("First: {:?}, Second: {:?}", first, now.elapsed());

// 3. ASSEMBLE STATS.
let now = std::time::Instant::now();
let row_cnt = row_cnts[0];
let mut column_comb_stats = HashMap::new();

Expand All @@ -506,7 +509,6 @@ impl TableStats<Counter<ColumnCombValue>, TDigest<Value>> {
);
column_comb_stats.insert(comb, column_stats);
}
println!("First: {:?}, Second: {:?}, Third: {:?}", first, second, now.elapsed());

Ok(Self {
row_cnt: row_cnt as usize,
Expand Down
32 changes: 19 additions & 13 deletions optd-gungnir/src/stats/hyperloglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ impl_byte_serializable_for_numeric!(usize, isize);
impl_byte_serializable_for_numeric!(f64, f32);

// Self-contained implementation of the HyperLogLog data structure.
impl<T> HyperLogLog<T>
impl<'a, T> HyperLogLog<T>
where
T: ByteSerializable,
T: ByteSerializable + 'a,
{
/// Creates and initializes a new empty HyperLogLog.
pub fn new(precision: u8) -> Self {
Expand All @@ -109,17 +109,23 @@ where
}
}

/// Digests a single ByteSerializable element into the HLL.
pub fn process(&mut self, element: &T)
where
    T: ByteSerializable,
{
    // TODO: The hash seed is fixed, so DoS attacks are ignored.
    let digest = murmur_hash(&element.to_bytes(), 0);
    // The low-order bits pick the register; the high-order bits give the zero streak.
    let bucket_mask = (1 << self.precision) - 1;
    let bucket = (digest & bucket_mask) as usize;
    let rank = self.zeros(digest) + 1;
    self.registers[bucket] = max(self.registers[bucket], rank);
}

/// Digests an array of ByteSerializable data into the HLL.
pub fn aggregate(&mut self, data: &[T])
pub fn aggregate<I>(&mut self, data: I)
where
I: Iterator<Item = &'a T>,
T: ByteSerializable,
{
for d in data {
let hash = murmur_hash(&d.to_bytes(), 0); // TODO: We ignore DoS attacks (seed).
let mask = (1 << (self.precision)) - 1;
let idx = (hash & mask) as usize; // LSB is bucket discriminator; MSB is zero streak.
self.registers[idx] = max(self.registers[idx], self.zeros(hash) + 1);
}
data.for_each(|e| self.process(e));
}

/// Merges two HLLs together and returns a new one.
Expand Down Expand Up @@ -192,7 +198,7 @@ mod tests {
let mut hll = HyperLogLog::new(12);

let data = vec!["a".to_string(), "b".to_string()];
hll.aggregate(&data);
hll.aggregate(data.iter());
assert_eq!(hll.n_distinct(), data.len() as u64);
}

Expand All @@ -201,7 +207,7 @@ mod tests {
let mut hll = HyperLogLog::new(12);

let data = vec![1, 2];
hll.aggregate(&data);
hll.aggregate(data.iter());
assert_eq!(hll.n_distinct(), data.len() as u64);
}

Expand Down Expand Up @@ -239,7 +245,7 @@ mod tests {
let relative_error = 0.05; // We allow a 5% relative error rate.

let strings = generate_random_strings(n_distinct, 100, 0);
hll.aggregate(&strings);
hll.aggregate(strings.iter());

assert!(is_close(
hll.n_distinct() as f64,
Expand All @@ -264,7 +270,7 @@ mod tests {
let curr_job_id = job_id.fetch_add(1, Ordering::SeqCst);

let strings = generate_random_strings(n_distinct, 100, curr_job_id);
local_hll.aggregate(&strings);
local_hll.aggregate(strings.iter());

assert!(is_close(
local_hll.n_distinct() as f64,
Expand Down
21 changes: 12 additions & 9 deletions optd-gungnir/src/stats/misragries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ pub struct MisraGries<T: PartialEq + Eq + Hash + Clone> {
}

// Self-contained implementation of the Misra-Gries data structure.
impl<T> MisraGries<T>
impl<'a, T> MisraGries<T>
where
T: PartialEq + Eq + Hash + Clone,
T: PartialEq + Eq + Hash + Clone + 'a,
{
/// Creates and initializes a new empty Misra-Gries.
pub fn new(k: u16) -> Self {
Expand All @@ -48,7 +48,7 @@ where
}

// Inserts an element occ times into the `self` Misra-Gries structure.
fn insert_element(&mut self, elem: &T, occ: i32) {
pub fn insert_element(&mut self, elem: &T, occ: i32) {
match self.frequencies.get_mut(elem) {
Some(freq) => {
*freq += occ; // Hit.
Expand Down Expand Up @@ -93,8 +93,11 @@ where
}

/// Digests an array of data into the Misra-Gries structure.
pub fn aggregate(&mut self, data: &[T]) {
data.iter().for_each(|key| self.insert_element(key, 1));
pub fn aggregate<I>(&mut self, data: I)
where
I: Iterator<Item = &'a T>,
{
data.for_each(|key| self.insert_element(&key, 1));
}

/// Merges another MisraGries into the current one.
Expand Down Expand Up @@ -131,7 +134,7 @@ mod tests {
let data = vec![0, 1, 2, 3];
let mut misra_gries = MisraGries::<i32>::new(data.len() as u16);

misra_gries.aggregate(&data);
misra_gries.aggregate(data.iter());

for key in misra_gries.most_frequent_keys() {
assert!(data.contains(key));
Expand All @@ -145,7 +148,7 @@ mod tests {

let mut misra_gries = MisraGries::<i32>::new(data.len() as u16);

misra_gries.aggregate(&data_dup);
misra_gries.aggregate(data_dup.iter());

for key in misra_gries.most_frequent_keys() {
assert!(data.contains(key));
Expand Down Expand Up @@ -189,7 +192,7 @@ mod tests {
let data = create_zipfian(n_distinct, 0);
let mut misra_gries = MisraGries::<i32>::new(k as u16);

misra_gries.aggregate(&data);
misra_gries.aggregate(data.iter());

check_zipfian(&misra_gries, n_distinct);
}
Expand All @@ -209,7 +212,7 @@ mod tests {
let curr_job_id = job_id.fetch_add(1, Ordering::SeqCst);

let data = create_zipfian(n_distinct, curr_job_id as u64);
local_misra_gries.aggregate(&data);
local_misra_gries.aggregate(data.iter());

check_zipfian(&local_misra_gries, n_distinct);

Expand Down
Loading

0 comments on commit 358ccc4

Please sign in to comment.