From b1a01aea5e4e77a90467586c13864678a4db5819 Mon Sep 17 00:00:00 2001 From: Ian Alexander Joiner <14581281+iajoiner@users.noreply.github.com> Date: Wed, 19 Jun 2024 21:15:36 -0400 Subject: [PATCH] fix!: separate element decoding error from verification error (#16) --- .../src/sql/proof/provable_query_result.rs | 36 ++++----- .../sql/proof/provable_query_result_test.rs | 75 ++++++++++-------- .../proof-of-sql/src/sql/proof/query_proof.rs | 9 +-- .../src/sql/proof/query_result.rs | 9 +++ .../sql/proof/result_element_serialization.rs | 77 +++++++++++-------- 5 files changed, 109 insertions(+), 97 deletions(-) diff --git a/crates/proof-of-sql/src/sql/proof/provable_query_result.rs b/crates/proof-of-sql/src/sql/proof/provable_query_result.rs index ad2e1c678..5ef72846e 100644 --- a/crates/proof-of-sql/src/sql/proof/provable_query_result.rs +++ b/crates/proof-of-sql/src/sql/proof/provable_query_result.rs @@ -83,11 +83,11 @@ impl ProvableQueryResult { evaluation_point: &[S], table_length: usize, column_result_fields: &[ColumnField], - ) -> Option> { + ) -> Result, QueryError> { assert_eq!(self.num_columns as usize, column_result_fields.len()); if !self.indexes.valid(table_length) { - return None; + return Err(QueryError::InvalidIndexes); } let evaluation_vec_len = self @@ -119,7 +119,6 @@ impl ProvableQueryResult { decode_and_convert::(&self.data[offset..]) } }?; - val += evaluation_vec[index as usize] * x; offset += sz; } @@ -127,10 +126,10 @@ impl ProvableQueryResult { } if offset != self.data.len() { - return None; + return Err(QueryError::MiscellaneousEvaluationError); } - Some(res) + Ok(res) } /// Convert the intermediate query result into a final query result @@ -150,56 +149,47 @@ impl ProvableQueryResult { .iter() .map(|field| match field.data_type() { ColumnType::Boolean => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::Boolean(col))) } ColumnType::SmallInt => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::SmallInt(col))) } ColumnType::Int => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::Int(col))) } ColumnType::BigInt => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::BigInt(col))) } ColumnType::Int128 => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::Int128(col))) } ColumnType::VarChar => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::InvalidString)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::VarChar(col))) } ColumnType::Scalar => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::Scalar(col))) } ColumnType::Decimal75(precision, scale) => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::Decimal75(precision, scale, col))) } ColumnType::TimestampTZ(tu, tz) => { - let (col, num_read) = decode_multiple_elements(&self.data[offset..], n) - .ok_or(QueryError::Overflow)?; + let (col, num_read) = decode_multiple_elements(&self.data[offset..], n)?; offset += num_read; Ok((field.name(), OwnedColumn::TimestampTZ(tu, tz, col))) } diff --git a/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs b/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs index f72c99e03..9c675f3ba 100644 --- a/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs +++ b/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs @@ -1,4 +1,4 @@ -use super::{ProvableQueryResult, ProvableResultColumn}; +use super::{ProvableQueryResult, ProvableResultColumn, QueryError}; use crate::{ base::{ database::{ColumnField, ColumnType}, @@ -206,9 +206,10 @@ fn evaluation_fails_if_indexes_are_out_of_range() { compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; - assert!(res - .evaluate(&evaluation_point, 4, &column_fields[..]) - .is_none()); + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::InvalidIndexes) + )); } #[test] @@ -225,9 +226,10 @@ fn evaluation_fails_if_indexes_are_not_sorted() { compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; - assert!(res - .evaluate(&evaluation_point, 4, &column_fields[..]) - .is_none()); + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::InvalidIndexes) + )); } #[test] @@ -245,9 +247,10 @@ fn evaluation_fails_if_extra_data_is_included() { compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; - assert!(res - .evaluate(&evaluation_point, 4, &column_fields[..]) - .is_none()); + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::MiscellaneousEvaluationError) + )); } #[test] @@ -266,9 +269,30 @@ fn evaluation_fails_if_the_result_cant_be_decoded() { compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); res.num_columns()]; - assert!(res - .evaluate(&evaluation_point, 4, &column_fields[..]) - .is_none()); + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::Overflow) + )); +} + +#[test] +fn evaluation_fails_if_integer_overflow_happens() { + let indexes = Indexes::Sparse(vec![0, 2]); + let values: [i64; 3] = [i32::MAX as i64 + 1_i64, 11, 12]; + let cols: [Box; 1] = [Box::new(values)]; + let res = ProvableQueryResult::new(&indexes, &cols); + let evaluation_point = [ + Curve25519Scalar::from(10u64), + Curve25519Scalar::from(100u64), + ]; + let mut evaluation_vec = [Curve25519Scalar::ZERO; 4]; + compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); + let column_fields = + vec![ColumnField::new("a".parse().unwrap(), ColumnType::Int); res.num_columns()]; + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::Overflow) + )); } #[test] @@ -286,9 +310,10 @@ fn evaluation_fails_if_data_is_missing() { compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); res.num_columns()]; - assert!(res - .evaluate(&evaluation_point, 4, &column_fields[..]) - .is_none()); + assert!(matches!( + res.evaluate(&evaluation_point, 4, &column_fields[..]), + Err(QueryError::Overflow) + )); } #[test] @@ -406,7 +431,6 @@ fn we_can_convert_a_provable_result_to_a_final_result_with_mixed_data_types() { .unwrap(); let column_fields: Vec = column_fields.iter().map(|v| v.into()).collect(); let schema = Arc::new(Schema::new(column_fields)); - println!("{:?}", res); let expected_res = RecordBatch::try_new( schema, vec![ @@ -446,20 +470,3 @@ fn we_cannot_convert_a_provable_result_with_invalid_string_data() { .to_owned_table::(&column_fields) .is_err()); } - -// TODO: we don't correctly detect overflow yet -// #[test] -// #[should_panic] -// fn we_can_detect_overflow() { -// let indexes = [0]; -// let values = [i64::MAX]; -// let cols : [Box; 1] = [ -// Box::new(values), -// ]; -// let res = ProvableQueryResult::new(& -// &indexes, -// &cols, -// ); -// let column_fields = vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]; -// let res = res.into_query_result(&column_fields).unwrap(); -// } diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 0d364738a..4e469ec12 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -242,16 +242,11 @@ impl QueryProof { let column_result_fields = expr.get_column_result_fields(); // compute the evaluation of the result MLEs - let result_evaluations = match result.evaluate( + let result_evaluations = result.evaluate( &subclaim.evaluation_point, table_length, &column_result_fields[..], - ) { - Some(evaluations) => evaluations, - _ => Err(ProofError::VerificationError( - "failed to evaluate intermediate result MLEs", - ))?, - }; + )?; // pass over the provable AST to fill in the verification builder let sumcheck_evaluations = SumcheckMleEvaluations::new( diff --git a/crates/proof-of-sql/src/sql/proof/query_result.rs b/crates/proof-of-sql/src/sql/proof/query_result.rs index 8165688ac..052b9bad4 100644 --- a/crates/proof-of-sql/src/sql/proof/query_result.rs +++ b/crates/proof-of-sql/src/sql/proof/query_result.rs @@ -17,6 +17,15 @@ pub enum QueryError { /// This just means that the database was supposed to respond with a string that was not valid UTF-8. #[error("String decode error")] InvalidString, + /// Decoding errors other than overflow and invalid string. + #[error("Miscellaneous decoding error")] + MiscellaneousDecodingError, + /// Indexes are invalid. + #[error("Invalid indexes")] + InvalidIndexes, + /// Miscellaneous evaluation error. + #[error("Miscellaneous evaluation error")] + MiscellaneousEvaluationError, /// The proof failed to verify. #[error(transparent)] ProofError(#[from] ProofError), diff --git a/crates/proof-of-sql/src/sql/proof/result_element_serialization.rs b/crates/proof-of-sql/src/sql/proof/result_element_serialization.rs index c6b94f2dd..824e1070e 100644 --- a/crates/proof-of-sql/src/sql/proof/result_element_serialization.rs +++ b/crates/proof-of-sql/src/sql/proof/result_element_serialization.rs @@ -1,10 +1,11 @@ +use super::QueryError; use crate::base::encode::VarInt; pub trait ProvableResultElement<'a> { fn required_bytes(&self) -> usize; fn encode(&self, out: &mut [u8]) -> usize; - fn decode(data: &'a [u8]) -> Option<(Self, usize)> + fn decode(data: &'a [u8]) -> Result<(Self, usize), QueryError> where Self: Sized; } @@ -20,8 +21,8 @@ impl ProvableResultElement<'_> for T { self.encode_var(out) } - fn decode(data: &[u8]) -> Option<(Self, usize)> { - VarInt::decode_var(data) + fn decode(data: &[u8]) -> Result<(Self, usize), QueryError> { + VarInt::decode_var(data).ok_or(QueryError::Overflow) } } @@ -39,16 +40,17 @@ impl<'a> ProvableResultElement<'a> for &'a [u8] { bytes_written } - fn decode(data: &'a [u8]) -> Option<(Self, usize)> { - let (len_buf, sizeof_usize) = ::decode_var(data)?; + fn decode(data: &'a [u8]) -> Result<(Self, usize), QueryError> { + let (len_buf, sizeof_usize) = + ::decode_var(data).ok_or(QueryError::MiscellaneousDecodingError)?; let bytes_read = len_buf + sizeof_usize; if data.len() < bytes_read { - return None; + return Err(QueryError::MiscellaneousDecodingError); } - Some((&data[sizeof_usize..bytes_read], bytes_read)) + Ok((&data[sizeof_usize..bytes_read], bytes_read)) } } @@ -60,7 +62,7 @@ impl<'a> ProvableResultElement<'a> for &'a str { fn encode(&self, out: &mut [u8]) -> usize { self.as_bytes().encode(out) } - fn decode(data: &'a [u8]) -> Option<(Self, usize)> { + fn decode(data: &'a [u8]) -> Result<(Self, usize), QueryError> { let (data, bytes_read) = <&[u8]>::decode(data)?; // arrow::array::StringArray only supports strings @@ -69,10 +71,13 @@ impl<'a> ProvableResultElement<'a> for &'a str { // StringArray will panic. So we add this restriction here to // prevent this scenario. if data.len() > i32::MAX as usize { - return None; + return Err(QueryError::MiscellaneousDecodingError); } - Some((std::str::from_utf8(data).ok()?, bytes_read)) + Ok(( + std::str::from_utf8(data).map_err(|_e| QueryError::InvalidString)?, + bytes_read, + )) } } @@ -84,25 +89,25 @@ impl ProvableResultElement<'_> for String { fn encode(&self, out: &mut [u8]) -> usize { self.as_str().encode(out) } - fn decode(data: &[u8]) -> Option<(Self, usize)> { + fn decode(data: &[u8]) -> Result<(Self, usize), QueryError> { decode_and_convert::<&str, String>(data) } } -pub fn decode_and_convert<'a, F, T>(data: &'a [u8]) -> Option<(T, usize)> +pub fn decode_and_convert<'a, F, T>(data: &'a [u8]) -> Result<(T, usize), QueryError> where F: ProvableResultElement<'a>, T: From, { let (val, num_read) = F::decode(data)?; - Some((val.into(), num_read)) + Ok((val.into(), num_read)) } /// Implement the decode operation for multiple rows pub fn decode_multiple_elements<'a, T: ProvableResultElement<'a>>( data: &'a [u8], n: usize, -) -> Option<(Vec, usize)> { +) -> Result<(Vec, usize), QueryError> { let mut res = Vec::with_capacity(n); let mut cnt = 0; for _ in 0..n { @@ -112,7 +117,7 @@ pub fn decode_multiple_elements<'a, T: ProvableResultElement<'a>>( cnt += num_read; } - Some((res, cnt)) + Ok((res, cnt)) } #[cfg(test)] @@ -175,12 +180,18 @@ mod tests { let value = Curve25519Scalar::from(i128::MAX) + Curve25519Scalar::from(1); let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); - assert_eq!(::decode(&out[..]), None); + assert!(matches!( + ::decode(&out[..]), + Err(QueryError::Overflow) + )); let value = Curve25519Scalar::from(i128::MIN) - Curve25519Scalar::from(1); let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); - assert_eq!(::decode(&out[..]), None); + assert!(matches!( + ::decode(&out[..]), + Err(QueryError::Overflow) + )); } #[test] @@ -392,8 +403,8 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); - assert!(::decode(&out[..]).is_some()); - assert!(::decode(&[]).is_none()); + assert!(::decode(&out[..]).is_ok()); + assert!(::decode(&[]).is_err()); } #[test] @@ -402,11 +413,11 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); - assert!(::decode(&out[..]).is_some()); + assert!(::decode(&out[..]).is_ok()); out[..].clone_from_slice(&vec![0b11111111; value.required_bytes()]); - assert!(::decode(&out[..]).is_none()); + assert!(::decode(&out[..]).is_err()); } #[test] @@ -415,11 +426,11 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); - assert!(<&str>::decode(&out[..]).is_some()); + assert!(<&str>::decode(&out[..]).is_ok()); let last_element = out.len(); out[last_element - 3..last_element].clone_from_slice(&[0xed, 0xa0, 0x80]); - assert!(<&str>::decode(&out[..]).is_none()); + assert!(<&str>::decode(&out[..]).is_err()); } #[test] @@ -428,7 +439,7 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); assert_eq!(out.len(), value.len().required_space()); - assert!(<&[u8]>::decode(&out[..0]).is_none()); + assert!(<&[u8]>::decode(&out[..0]).is_err()); } #[test] @@ -438,14 +449,14 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); assert_eq!(out.len(), value.len().required_space() + value.len()); - assert!(<&[u8]>::decode(&out[..]).is_some()); + assert!(<&[u8]>::decode(&out[..]).is_ok()); assert_eq!( (value.len() + 1).required_space(), value.len().required_space() ); (value.len() + 1).encode_var(&mut out[..]); - assert!(<&[u8]>::decode(&out[..]).is_none()); + assert!(<&[u8]>::decode(&out[..]).is_err()); } #[test] @@ -455,7 +466,7 @@ mod tests { let mut out = vec![0_u8; value.required_bytes()]; value.encode(&mut out[..]); assert_eq!(out.len(), value.len().required_space() + value.len()); - assert!(<&[u8]>::decode(&out[..]).is_some()); + assert!(<&[u8]>::decode(&out[..]).is_ok()); assert_eq!( value.len().required_space(), @@ -480,7 +491,7 @@ mod tests { assert_eq!(read_column.0, vec!["ABC"]); assert_eq!(read_column.1, "ABC".required_bytes()); - assert!(decode_multiple_elements::<&str>(&out[..], 2).is_none()); + assert!(decode_multiple_elements::<&str>(&out[..], 2).is_err()); } #[test] @@ -493,7 +504,7 @@ mod tests { assert_eq!(read_column.0, data.to_vec()); assert_eq!(read_column.1, out.len()); - assert!(decode_multiple_elements::<&str>(&out[..], data.len() + 1).is_none()); + assert!(decode_multiple_elements::<&str>(&out[..], data.len() + 1).is_err()); } #[test] @@ -507,7 +518,7 @@ mod tests { assert_eq!(read_column.1, out.len()); // we remove last element - assert!(decode_multiple_elements::<&str>(&out[..out.len() - 1], data.len()).is_none()); + assert!(decode_multiple_elements::<&str>(&out[..out.len() - 1], data.len()).is_err()); // we change the amount of elements specified in the buffer to be `data[1].len() + 1` assert_eq!( @@ -515,7 +526,7 @@ mod tests { data[1].len().required_space() ); (data[1].len() + 1).encode_var(&mut out[data[0].required_bytes()..]); - assert!(decode_multiple_elements::<&str>(&out[..], data.len()).is_none()); + assert!(decode_multiple_elements::<&str>(&out[..], data.len()).is_err()); } #[test] @@ -526,10 +537,10 @@ mod tests { assert_eq!((s_len - 1_usize).required_space(), s_len.required_space()); (s_len - 1_usize).encode_var(&mut s[..]); assert!( - <&str>::decode(&s[..(s_len - 1_usize + (s_len - 1_usize).required_space())]).is_some() + <&str>::decode(&s[..(s_len - 1_usize + (s_len - 1_usize).required_space())]).is_ok() ); s_len.encode_var(&mut s[..]); - assert!(<&str>::decode(&s[..]).is_none()); + assert!(<&str>::decode(&s[..]).is_err()); } }