diff --git a/crates/proof-of-sql/src/utils/parquet_to_commitment_blob.rs b/crates/proof-of-sql/src/utils/parquet_to_commitment_blob.rs index 1fb08bed..f5fb3f56 100644 --- a/crates/proof-of-sql/src/utils/parquet_to_commitment_blob.rs +++ b/crates/proof-of-sql/src/utils/parquet_to_commitment_blob.rs @@ -4,13 +4,10 @@ use crate::{ }; use arrow::{ array::{ - Array, ArrayRef, ArrowPrimitiveType, BooleanArray, Decimal128Array, Decimal256Array, - Decimal256Builder, Int16Array, Int32Array, Int64Array, Int8Array, PrimitiveArray, - RecordBatch, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, Decimal256Array, Decimal256Builder, Int32Array, RecordBatch, StringArray, }, compute::{sort_to_indices, take}, - datatypes::{i256, DataType, Field, Schema, TimeUnit}, + datatypes::{i256, DataType, Field, Schema}, error::ArrowError, }; use core::str::FromStr; @@ -23,23 +20,26 @@ use std::{collections::HashMap, fs::File, io::Write, path::PathBuf, sync::Arc}; static PARQUET_FILE_PROOF_ORDER_COLUMN: &str = "META_ROW_NUMBER"; /// Performs the following: -/// Reads a collection of parquet files which in aggregate represent a single table of data, -/// Calculates the `TableCommitment` for the table using multiple commitment strategies, -/// Serializes each commitment to a blob, which is saved in the same directory as the original parquet file +/// Reads a collection of parquet files which in aggregate represent a single range from a table of data, +/// Calculates the `TableCommitment` for the table using one or more commitment strategies, +/// Serializes each commitment to a blob /// /// # Panics /// /// Panics when any part of the process fails -pub fn read_parquet_file_to_commitment_as_blob( +pub fn convert_historical_parquet_file_to_commitment_blob( parquet_files: &Vec, output_path_prefix: &str, prover_setup: &DoryProverPublicSetup, big_decimal_columns: &[(String, u8, i8)], ) { + // Compute and collect TableCommitments per RecordBatch per file. let mut commitments: Vec> = parquet_files .par_iter() .flat_map(|path| { println!("Committing to {}..", path.as_path().to_str().unwrap()); + + // Collect RecordBatches from file let file = File::open(path).unwrap(); let reader = ParquetRecordBatchReaderBuilder::try_new(file) .unwrap() @@ -49,34 +49,34 @@ pub fn read_parquet_file_to_commitment_as_blob( let record_batches: Vec = record_batch_results .into_iter() .map(|record_batch_result| { + // Sorting can probably be removed sort_record_batch_by_meta_row_number(&record_batch_result.unwrap()) }) .collect(); + + // Compute and collect the TableCommitments for each RecordBatch in the file. let schema = record_batches.first().unwrap().schema(); - println!( - "File row COUNT: {}", - record_batches - .iter() - .map(RecordBatch::num_rows) - .sum::() - ); let commitments: Vec<_> = record_batches .into_par_iter() - .map(|mut unmodified_record_batch| { - let meta_row_number_column = unmodified_record_batch + .map(|mut record_batch| { + // We use the proof column only to identify the offset used to compute the commitments. It can be removed afterward. + let meta_row_number_column = record_batch .column_by_name(PARQUET_FILE_PROOF_ORDER_COLUMN) .unwrap() .as_any() .downcast_ref::() .unwrap(); - let offset = meta_row_number_column.value(0) - 1; - unmodified_record_batch + record_batch .remove_column(schema.index_of(PARQUET_FILE_PROOF_ORDER_COLUMN).unwrap()); - let record_batch = replace_nulls_within_record_batch(&correct_utf8_fields( - &unmodified_record_batch, + + // Replace appropriate string columns with decimal columns. + let record_batch = convert_utf8_to_decimal_75_within_record_batch_as_appropriate( + &record_batch, big_decimal_columns.to_vec(), - )); + ); + + // Calculate and return TableCommitment TableCommitment::::try_from_record_batch_with_offset( &record_batch, offset as usize, @@ -85,166 +85,24 @@ pub fn read_parquet_file_to_commitment_as_blob( .unwrap() }) .collect(); - println!("Commitments generated"); commitments }) .collect(); println!("done computing per-file commitments, now sorting and aggregating"); + + // We sort the TableCommitment collections in order to avoid non-contiguous errors. commitments.sort_by(|commitment_a, commitment_b| { commitment_a.range().start.cmp(&commitment_b.range().start) }); - //aggregate_commitments_to_blob(unzipped.0, format!("{output_path_prefix}-dory-commitment")); + // Sum commitments and write commitments to blob aggregate_commitments_to_blob( commitments, &format!("{output_path_prefix}-dory-commitment"), ); } -/// # Panics -/// -/// Panics when any part of the process fails -fn aggregate_commitments_to_blob Deserialize<'a>>( - commitments: Vec>, - output_file_base: &str, -) { - let commitment = commitments - .into_iter() - .fold( - None, - |aggregate_commitment: Option>, next_commitment| { - match aggregate_commitment { - Some(agg) => Some(agg.try_add(next_commitment).unwrap()), - None => Some(next_commitment), - } - }, - ) - .unwrap(); - write_commitment_to_blob(&commitment, output_file_base); -} - -/// # Panics -fn write_commitment_to_blob Deserialize<'a>>( - commitment: &TableCommitment, - output_file_base: &str, -) { - let bytes: Vec = to_allocvec(commitment).unwrap(); - let path_extension = "txt"; - let mut output_file = File::create(format!("{output_file_base}.{path_extension}")).unwrap(); - output_file.write_all(&bytes).unwrap(); -} - -fn replace_nulls_primitive(array: &PrimitiveArray) -> PrimitiveArray { - PrimitiveArray::from_iter_values( - array - .iter() - .map(|value: Option<::Native>| value.unwrap_or_default()), - ) -} - -/// # Panics -fn replace_nulls_within_record_batch(record_batch: &RecordBatch) -> RecordBatch { - let schema = record_batch.schema(); - let new_columns: Vec<_> = record_batch - .columns() - .iter() - .map(|column| { - if column.is_nullable() { - let column_type = column.data_type(); - let column: ArrayRef = match column_type { - DataType::Int8 => Arc::new(replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - )), - DataType::Int16 => Arc::new(replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - )), - DataType::Int32 => Arc::new(replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - )), - DataType::Int64 => Arc::new(replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - )), - - DataType::Decimal128(precision, scale) => Arc::new( - replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - ) - .with_precision_and_scale(*precision, *scale) - .unwrap(), - ), - DataType::Decimal256(precision, scale) => Arc::new( - replace_nulls_primitive( - column.as_any().downcast_ref::().unwrap(), - ) - .with_precision_and_scale(*precision, *scale) - .unwrap(), - ), - DataType::Timestamp(TimeUnit::Second, timezone) => Arc::new( - replace_nulls_primitive( - column - .as_any() - .downcast_ref::() - .unwrap(), - ) - .with_timezone_opt(timezone.clone()), - ), - DataType::Timestamp(TimeUnit::Millisecond, timezone) => Arc::new( - replace_nulls_primitive( - column - .as_any() - .downcast_ref::() - .unwrap(), - ) - .with_timezone_opt(timezone.clone()), - ), - DataType::Timestamp(TimeUnit::Microsecond, timezone) => Arc::new( - replace_nulls_primitive( - column - .as_any() - .downcast_ref::() - .unwrap(), - ) - .with_timezone_opt(timezone.clone()), - ), - DataType::Timestamp(TimeUnit::Nanosecond, timezone) => Arc::new( - replace_nulls_primitive( - column - .as_any() - .downcast_ref::() - .unwrap(), - ) - .with_timezone_opt(timezone.clone()), - ), - DataType::Boolean => Arc::new( - column - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|element| Some(element.unwrap_or(false))) - .collect::(), - ), - DataType::Utf8 => Arc::new(StringArray::from_iter_values( - column - .as_any() - .downcast_ref::() - .unwrap() - .iter() - .map(|element| element.unwrap_or("")), - )), - _ => unimplemented!(), - }; - - column - } else { - column.clone() - } - }) - .collect(); - RecordBatch::try_new(schema, new_columns).unwrap() -} - /// # Panics fn sort_record_batch_by_meta_row_number(record_batch: &RecordBatch) -> RecordBatch { let schema = record_batch.schema(); @@ -265,13 +123,10 @@ fn sort_record_batch_by_meta_row_number(record_batch: &RecordBatch) -> RecordBat } /// # Panics -fn cast_string_array_to_decimal256_array( - string_array: &[Option], - precision: u8, - scale: i8, -) -> Decimal256Array { - let mut builder = - Decimal256Builder::default().with_data_type(DataType::Decimal256(precision, scale)); +fn cast_string_array_to_decimal256_array(string_array: &StringArray, scale: i8) -> Decimal256Array { + let corrected_precision = 75; + let mut builder = Decimal256Builder::default() + .with_data_type(DataType::Decimal256(corrected_precision, scale)); string_array.iter().for_each(|value| match value { Some(v) => { @@ -286,7 +141,7 @@ fn cast_string_array_to_decimal256_array( } /// # Panics -fn correct_utf8_fields( +fn convert_utf8_to_decimal_75_within_record_batch_as_appropriate( record_batch: &RecordBatch, big_decimal_columns: Vec<(String, u8, i8)>, ) -> RecordBatch { @@ -305,23 +160,18 @@ fn correct_utf8_fields( let column = pointer_column.clone(); let column_name = field.name().to_lowercase(); if field.data_type() == &DataType::Utf8 { - let string_vec: Vec> = column + let string_array: StringArray = column .as_any() .downcast_ref::() .unwrap() - .into_iter() - .map(|s| s.map(|st| st.replace("\0", ""))) - .collect(); + .clone(); big_decimal_columns_lookup .get(&column_name) - .map(|(precision, scale)| { - Arc::new(cast_string_array_to_decimal256_array( - &string_vec, - *precision, - *scale, - )) as ArrayRef + .map(|(_precision, scale)| { + Arc::new(cast_string_array_to_decimal256_array(&string_array, *scale)) + as ArrayRef }) - .unwrap_or(Arc::new(StringArray::from(string_vec))) + .unwrap_or(Arc::new(string_array)) } else { Arc::new(column) } @@ -353,30 +203,52 @@ fn correct_utf8_fields( RecordBatch::try_new(new_schema.into(), columns).unwrap() } +/// # Panics +fn aggregate_commitments_to_blob Deserialize<'a>>( + commitments: Vec>, + output_file_base: &str, +) { + let commitment = commitments + .into_iter() + .fold( + None, + |aggregate_commitment: Option>, next_commitment| { + match aggregate_commitment { + Some(agg) => Some(agg.try_add(next_commitment).unwrap()), + None => Some(next_commitment), + } + }, + ) + .unwrap(); + write_commitment_to_blob(&commitment, output_file_base); +} + +/// # Panics +fn write_commitment_to_blob Deserialize<'a>>( + commitment: &TableCommitment, + output_file_base: &str, +) { + let bytes: Vec = to_allocvec(commitment).unwrap(); + let path_extension = "txt"; + let mut output_file = File::create(format!("{output_file_base}.{path_extension}")).unwrap(); + output_file.write_all(&bytes).unwrap(); +} + #[cfg(test)] mod tests { + use super::cast_string_array_to_decimal256_array; use crate::{ base::commitment::{Commitment, TableCommitment}, proof_primitive::dory::{ DoryCommitment, DoryProverPublicSetup, ProverSetup, PublicParameters, }, utils::parquet_to_commitment_blob::{ - correct_utf8_fields, read_parquet_file_to_commitment_as_blob, - replace_nulls_within_record_batch, PARQUET_FILE_PROOF_ORDER_COLUMN, + convert_historical_parquet_file_to_commitment_blob, PARQUET_FILE_PROOF_ORDER_COLUMN, }, }; use arrow::{ - array::{ - ArrayRef, ArrowPrimitiveType, BooleanArray, Decimal128Array, Decimal256Builder, - Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, - }, - datatypes::{ - i256, DataType, Decimal128Type, Field, Int16Type, Int32Type, Int64Type, Int8Type, - Schema, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, - }, + array::{ArrayRef, Decimal256Builder, Int32Array, RecordBatch, StringArray}, + datatypes::{i256, DataType}, }; use parquet::{arrow::ArrowWriter, basic::Compression, file::properties::WriterProperties}; use postcard::from_bytes; @@ -386,7 +258,6 @@ mod tests { use std::{ fs::{self, File}, io::Read, - panic, path::Path, sync::Arc, }; @@ -403,7 +274,7 @@ mod tests { writer.close().unwrap(); } - fn read_commitment_from_blob Deserialize<'a>>( + fn deserialize_commitment_from_file Deserialize<'a>>( path: &str, ) -> TableCommitment { let mut blob_file = File::open(path).unwrap(); @@ -419,413 +290,38 @@ mod tests { } #[test] - fn we_can_replace_nulls() { - let schema = Arc::new(Schema::new(vec![ - Field::new("utf8", DataType::Utf8, true), - Field::new("boolean", DataType::Boolean, true), - Field::new( - "timestamp_second", - DataType::Timestamp(arrow::datatypes::TimeUnit::Second, None), - true, - ), - Field::new( - "timestamp_millisecond", - DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, None), - true, - ), - Field::new( - "timestamp_microsecond", - DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), - true, - ), - Field::new( - "timestamp_nanosecond", - DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), - true, - ), - Field::new("decimal128", DataType::Decimal128(38, 10), true), - Field::new("int64", DataType::Int64, true), - Field::new("int32", DataType::Int32, true), - Field::new("int16", DataType::Int16, true), - Field::new("int8", DataType::Int8, true), - ])); - - let utf8 = Arc::new(StringArray::from(vec![ - Some("a"), - None, - Some("c"), - Some("d"), - None, - ])) as ArrayRef; - let utf8_denulled = Arc::new(StringArray::from(vec![ - Some("a"), - Some(""), - Some("c"), - Some("d"), - Some(""), - ])) as ArrayRef; - - let boolean = Arc::new(BooleanArray::from(vec![ - Some(true), - None, - Some(false), - Some(true), - None, - ])) as ArrayRef; - let boolean_denulled = Arc::new(BooleanArray::from(vec![ - Some(true), - Some(false), - Some(false), - Some(true), - Some(false), - ])) as ArrayRef; - - let timestamp_second = Arc::new(TimestampSecondArray::from(vec![ - Some(1_627_846_260), - None, - Some(1_627_846_262), - Some(1_627_846_263), - None, - ])) as ArrayRef; - let timestamp_second_denulled = Arc::new(TimestampSecondArray::from(vec![ - Some(1_627_846_260), - Some(TimestampSecondType::default_value()), - Some(1_627_846_262), - Some(1_627_846_263), - Some(TimestampSecondType::default_value()), - ])) as ArrayRef; - - let timestamp_millisecond = Arc::new(TimestampMillisecondArray::from(vec![ - Some(1_627_846_260_000), - None, - Some(1_627_846_262_000), - Some(1_627_846_263_000), - None, - ])) as ArrayRef; - let timestamp_millisecond_denulled = Arc::new(TimestampMillisecondArray::from(vec![ - Some(1_627_846_260_000), - Some(TimestampMillisecondType::default_value()), - Some(1_627_846_262_000), - Some(1_627_846_263_000), - Some(TimestampMillisecondType::default_value()), - ])) as ArrayRef; - - let timestamp_microsecond = Arc::new(TimestampMicrosecondArray::from(vec![ - Some(1_627_846_260_000_000), - None, - Some(1_627_846_262_000_000), - Some(1_627_846_263_000_000), - None, - ])) as ArrayRef; - let timestamp_microsecond_denulled = Arc::new(TimestampMicrosecondArray::from(vec![ - Some(1_627_846_260_000_000), - Some(TimestampMicrosecondType::default_value()), - Some(1_627_846_262_000_000), - Some(1_627_846_263_000_000), - Some(TimestampMicrosecondType::default_value()), - ])) as ArrayRef; - - let timestamp_nanosecond = Arc::new(TimestampNanosecondArray::from(vec![ - Some(1_627_846_260_000_000_000), - None, - Some(1_627_846_262_000_000_000), - Some(1_627_846_263_000_000_000), - None, - ])) as ArrayRef; - let timestamp_nanosecond_denulled = Arc::new(TimestampNanosecondArray::from(vec![ - Some(1_627_846_260_000_000_000), - Some(TimestampNanosecondType::default_value()), - Some(1_627_846_262_000_000_000), - Some(1_627_846_263_000_000_000), - Some(TimestampNanosecondType::default_value()), - ])) as ArrayRef; - - let decimal128 = Arc::new(Decimal128Array::from(vec![ - Some(12_345_678_901_234_567_890_i128), - None, - Some(23_456_789_012_345_678_901_i128), - Some(34_567_890_123_456_789_012_i128), - None, - ])) as ArrayRef; - let decimal128_denulled = Arc::new(Decimal128Array::from(vec![ - Some(12_345_678_901_234_567_890_i128), - Some(Decimal128Type::default_value()), - Some(23_456_789_012_345_678_901_i128), - Some(34_567_890_123_456_789_012_i128), - Some(Decimal128Type::default_value()), - ])) as ArrayRef; - - let int64 = Arc::new(Int64Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - None, - ])) as ArrayRef; - let int64_denulled = Arc::new(Int64Array::from(vec![ - Some(1), - Some(Int64Type::default_value()), - Some(3), - Some(4), - Some(Int64Type::default_value()), - ])) as ArrayRef; - - let int32 = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - None, - ])) as ArrayRef; - let int32_denulled = Arc::new(Int32Array::from(vec![ - Some(1), - Some(Int32Type::default_value()), - Some(3), - Some(4), - Some(Int32Type::default_value()), - ])) as ArrayRef; - - let int16 = Arc::new(Int16Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - None, - ])) as ArrayRef; - let int16_denulled = Arc::new(Int16Array::from(vec![ - Some(1), - Some(Int16Type::default_value()), - Some(3), - Some(4), - Some(Int16Type::default_value()), - ])) as ArrayRef; - - let int8 = - Arc::new(Int8Array::from(vec![Some(1), None, Some(3), Some(4), None])) as ArrayRef; - let int8_denulled = Arc::new(Int8Array::from(vec![ - Some(1), - Some(Int8Type::default_value()), - Some(3), - Some(4), - Some(Int8Type::default_value()), - ])) as ArrayRef; - - let record_batch = RecordBatch::try_new( - schema.clone(), - vec![ - utf8, - boolean, - timestamp_second, - timestamp_millisecond, - timestamp_microsecond, - timestamp_nanosecond, - decimal128, - int64, - int32, - int16, - int8, - ], - ) - .unwrap(); - let record_batch_denulled = RecordBatch::try_new( - schema, - vec![ - utf8_denulled, - boolean_denulled, - timestamp_second_denulled, - timestamp_millisecond_denulled, - timestamp_microsecond_denulled, - timestamp_nanosecond_denulled, - decimal128_denulled, - int64_denulled, - int32_denulled, - int16_denulled, - int8_denulled, - ], - ) - .unwrap(); - - let null_replaced_batch = replace_nulls_within_record_batch(&record_batch); - assert_eq!(null_replaced_batch, record_batch_denulled); - } - - #[test] - fn we_can_correct_utf8_columns() { - let original_schema = Arc::new(Schema::new(vec![ - Arc::new(Field::new("nullable_regular_string", DataType::Utf8, true)), - Arc::new(Field::new("nullable_big_decimal", DataType::Utf8, true)), - Arc::new(Field::new("not_null_regular_string", DataType::Utf8, false)), - Arc::new(Field::new("not_null_big_decimal", DataType::Utf8, false)), - Arc::new(Field::new("nullable_int", DataType::Int32, true)), - Arc::new(Field::new("not_null_int", DataType::Int32, false)), - ])); - let corrected_schema = Arc::new(Schema::new(vec![ - Arc::new(Field::new("nullable_regular_string", DataType::Utf8, true)), - Arc::new(Field::new( - "nullable_big_decimal", - DataType::Decimal256(25, 4), - true, - )), - Arc::new(Field::new("not_null_regular_string", DataType::Utf8, false)), - Arc::new(Field::new( - "not_null_big_decimal", - DataType::Decimal256(25, 4), - false, - )), - Arc::new(Field::new("nullable_int", DataType::Int32, true)), - Arc::new(Field::new("not_null_int", DataType::Int32, false)), - ])); - - let original_nullable_regular_string_array: ArrayRef = Arc::new(StringArray::from(vec![ - None, - Some("Bob"), - Some("Char\0lie"), - None, - Some("Eve"), - ])); - let corrected_nullable_regular_string_array: ArrayRef = Arc::new(StringArray::from(vec![ - None, - Some("Bob"), - Some("Charlie"), - None, - Some("Eve"), - ])); - let original_nullable_big_decimal_array: ArrayRef = Arc::new(StringArray::from(vec![ - Some("1234.56"), - None, - Some("45321E6"), - Some("123e4"), - None, - ])); - let mut corrected_nullable_big_decimal_array_builder = - Decimal256Builder::default().with_data_type(DataType::Decimal256(25, 4)); - corrected_nullable_big_decimal_array_builder.append_option(Some(i256::from(12_345_600))); - corrected_nullable_big_decimal_array_builder.append_null(); - corrected_nullable_big_decimal_array_builder - .append_option(Some(i256::from(453_210_000_000_000i64))); - corrected_nullable_big_decimal_array_builder - .append_option(Some(i256::from(12_300_000_000i64))); - corrected_nullable_big_decimal_array_builder.append_null(); - let corrected_nullable_big_decimal_array: ArrayRef = - Arc::new(corrected_nullable_big_decimal_array_builder.finish()); - let original_not_null_regular_string_array: ArrayRef = - Arc::new(StringArray::from(vec!["A", "B", "C\0", "D", "E"])); - let corrected_not_null_regular_string_array: ArrayRef = - Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E"])); - let original_not_null_big_decimal_array: ArrayRef = - Arc::new(StringArray::from(vec!["1", "2.34", "5e6", "12", "1E4"])); - let mut corrected_not_null_big_decimal_array_builder = - Decimal256Builder::default().with_data_type(DataType::Decimal256(25, 4)); - corrected_not_null_big_decimal_array_builder.append_value(i256::from(10_000)); - corrected_not_null_big_decimal_array_builder.append_value(i256::from(23_400)); - corrected_not_null_big_decimal_array_builder.append_value(i256::from(50_000_000_000i64)); - corrected_not_null_big_decimal_array_builder.append_value(i256::from(120_000)); - corrected_not_null_big_decimal_array_builder.append_value(i256::from(100_000_000)); - let corrected_not_null_big_decimal_array: ArrayRef = - Arc::new(corrected_not_null_big_decimal_array_builder.finish()); - - let nullable_int_array: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(10), - None, - Some(30), - Some(40), - None, - ])); - let not_null_int_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - - let original_record_batch = RecordBatch::try_new( - original_schema, - vec![ - original_nullable_regular_string_array, - original_nullable_big_decimal_array, - original_not_null_regular_string_array, - original_not_null_big_decimal_array, - nullable_int_array.clone(), - not_null_int_array.clone(), - ], - ) - .unwrap(); - - let expected_corrected_record_batch = RecordBatch::try_new( - corrected_schema, - vec![ - corrected_nullable_regular_string_array, - corrected_nullable_big_decimal_array, - corrected_not_null_regular_string_array, - corrected_not_null_big_decimal_array, - nullable_int_array, - not_null_int_array, - ], - ) - .unwrap(); - - let big_decimal_columns = vec![ - ("nullable_big_decimal".to_string(), 25, 4), - ("not_null_big_decimal".to_string(), 25, 4), - ]; - let corrected_record_batch = - correct_utf8_fields(&original_record_batch, big_decimal_columns); - - assert_eq!(corrected_record_batch, expected_corrected_record_batch); - } - - #[test] - fn we_can_fail_if_datatype_of_big_decimal_column_is_not_string() { - let err = panic::catch_unwind(|| { - let string_array: ArrayRef = Arc::new(StringArray::from(vec![ - None, - Some("123"), - Some("345"), - None, - Some("567"), - ])); - let schema = Arc::new(Schema::new(vec![Arc::new(Field::new( - "nullable_big_decimal", - DataType::Int16, - true, - ))])); - let record_batch = RecordBatch::try_new(schema, vec![string_array]).unwrap(); - let big_decimal_columns = vec![("nullable_big_decimal".to_string(), 25, 4)]; - let _test = correct_utf8_fields(&record_batch, big_decimal_columns); - }); - assert!(err.is_err()); - } - - #[test] - fn we_can_fail_if_big_decimal_column_is_not_castable() { - let err = panic::catch_unwind(|| { - let string_array: ArrayRef = Arc::new(StringArray::from(vec![ - None, - Some("Bob"), - Some("Charlie"), - None, - Some("Eve"), - ])); - let schema = Arc::new(Schema::new(vec![Arc::new(Field::new( - "nullable_regular_string", - DataType::Utf8, - true, - ))])); - let record_batch = RecordBatch::try_new(schema, vec![string_array]).unwrap(); - let big_decimal_columns = vec![("nullable_regular_string".to_string(), 25, 4)]; - let _test = correct_utf8_fields(&record_batch, big_decimal_columns); - }); - assert!(err.is_err()); - } - - #[test] - fn we_can_retrieve_commitments_and_save_to_file() { + fn we_can_convert_historical_parquet_file_to_commitment_blob() { + // Purge any old files let parquet_path_1 = "example-1.parquet"; let parquet_path_2 = "example-2.parquet"; let dory_commitment_path = "example-dory-commitment.txt"; delete_file_if_exists(parquet_path_1); delete_file_if_exists(parquet_path_2); delete_file_if_exists(dory_commitment_path); + + // ARRANGE + + // Prepare prover setup + let setup_seed = "SpaceAndTime".to_string(); + let mut rng = { + let seed_bytes = setup_seed + .bytes() + .chain(std::iter::repeat(0u8)) + .take(32) + .collect::>() + .try_into() + .expect("collection is guaranteed to contain 32 elements"); + ChaCha20Rng::from_seed(seed_bytes) + }; + let public_parameters = PublicParameters::rand(4, &mut rng); + let prover_setup = ProverSetup::from(&public_parameters); + let dory_prover_setup: DoryProverPublicSetup = DoryProverPublicSetup::new(&prover_setup, 3); + + // Create two RecordBatches with the same schema let proof_column_1 = Int32Array::from(vec![1, 2]); let column_1 = Int32Array::from(vec![2, 1]); let proof_column_2 = Int32Array::from(vec![3, 4]); let column_2 = Int32Array::from(vec![3, 4]); - let column = Int32Array::from(vec![2, 1, 3, 4]); let record_batch_1 = RecordBatch::try_from_iter(vec![ ( PARQUET_FILE_PROOF_ORDER_COLUMN, @@ -842,41 +338,58 @@ mod tests { ("column", Arc::new(column_2) as ArrayRef), ]) .unwrap(); - let record_batch = - RecordBatch::try_from_iter(vec![("column", Arc::new(column) as ArrayRef)]).unwrap(); + + // Write RecordBatches to parquet files create_mock_file_from_record_batch(parquet_path_1, &record_batch_1); create_mock_file_from_record_batch(parquet_path_2, &record_batch_2); - let setup_seed = "SpaceAndTime".to_string(); - let mut rng = { - // Convert the seed string to bytes and create a seeded RNG - let seed_bytes = setup_seed - .bytes() - .chain(std::iter::repeat(0u8)) - .take(32) - .collect::>() - .try_into() - .expect("collection is guaranteed to contain 32 elements"); - ChaCha20Rng::from_seed(seed_bytes) // Seed ChaChaRng - }; - let public_parameters = PublicParameters::rand(4, &mut rng); - let prover_setup = ProverSetup::from(&public_parameters); - let dory_prover_setup: DoryProverPublicSetup = DoryProverPublicSetup::new(&prover_setup, 3); - read_parquet_file_to_commitment_as_blob( + + // ACT + convert_historical_parquet_file_to_commitment_blob( &vec![parquet_path_1.into(), parquet_path_2.into()], "example", &dory_prover_setup, &Vec::new(), ); + + // ASSERT + + // Identify expected commitments + let expected_column = Int32Array::from(vec![2, 1, 3, 4]); + let expected_record_batch = + RecordBatch::try_from_iter(vec![("column", Arc::new(expected_column) as ArrayRef)]).unwrap(); + let expected_commitment = TableCommitment::::try_from_record_batch( + &expected_record_batch, + &dory_prover_setup, + ) + .unwrap(); + assert_eq!( - read_commitment_from_blob::(dory_commitment_path), - TableCommitment::::try_from_record_batch( - &record_batch, - &dory_prover_setup - ) - .unwrap() + deserialize_commitment_from_file::(dory_commitment_path), + expected_commitment ); + + // Tear down delete_file_if_exists(parquet_path_1); delete_file_if_exists(parquet_path_2); delete_file_if_exists(dory_commitment_path); } + + #[test] + fn we_can_cast_string_array_to_decimal_75() { + // ARRANGE + let string_array: StringArray = + StringArray::from(vec![Some("123.45"), None, Some("234.56"), Some("789.01")]); + + // ACT + let decimal_75_array = cast_string_array_to_decimal256_array(&string_array, 2); + + // ASSERT + let mut expected_decimal_75_array = + Decimal256Builder::default().with_data_type(DataType::Decimal256(75, 2)); + expected_decimal_75_array.append_value(i256::from(12_345)); + expected_decimal_75_array.append_null(); + expected_decimal_75_array.append_value(i256::from(23_456)); + expected_decimal_75_array.append_value(i256::from(78_901)); + assert_eq!(decimal_75_array, expected_decimal_75_array.finish()); + } } diff --git a/scripts/parquet-to-commitments/src/main.rs b/scripts/parquet-to-commitments/src/main.rs index eea50625..d0711f46 100644 --- a/scripts/parquet-to-commitments/src/main.rs +++ b/scripts/parquet-to-commitments/src/main.rs @@ -8,7 +8,7 @@ use glob::glob; use proof_of_sql::{ proof_primitive::dory::{DoryProverPublicSetup, ProverSetup, PublicParameters}, utils::{ - parquet_to_commitment_blob::read_parquet_file_to_commitment_as_blob, + parquet_to_commitment_blob::convert_historical_parquet_file_to_commitment_blob, parse::find_bigdecimals, }, }; @@ -96,7 +96,7 @@ fn main() { let full_output_prefix = format!("{output_prefix}-{namespace}-{table_name}"); let result = panic::catch_unwind(|| { - read_parquet_file_to_commitment_as_blob( + convert_historical_parquet_file_to_commitment_blob( &parquets_for_table, &full_output_prefix, &dory_prover_setup,