diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 5bb0fa19d..990443fca 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -207,15 +207,15 @@ pub enum ColumnType { /// Mapped to String #[serde(alias = "VARCHAR", alias = "varchar")] VarChar, - /// Mapped to Curve25519Scalar - #[serde(alias = "SCALAR", alias = "scalar")] - Scalar, /// Mapped to i256 #[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")] Decimal75(Precision, i8), /// Mapped to i64 #[serde(alias = "TIMESTAMP", alias = "timestamp")] TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone), + /// Mapped to Curve25519Scalar + #[serde(alias = "SCALAR", alias = "scalar")] + Scalar, } impl ColumnType { @@ -240,6 +240,44 @@ impl ColumnType { ) } + /// Returns the number of bits in the integer type if it is an integer type. Otherwise, return None. + pub fn to_integer_bits(&self) -> Option { + match self { + ColumnType::SmallInt => Some(16), + ColumnType::Int => Some(32), + ColumnType::BigInt => Some(64), + ColumnType::Int128 => Some(128), + _ => None, + } + } + + /// Returns the ColumnType of the integer type with the given number of bits if it is a valid integer type. + /// + /// Otherwise, return None. + pub fn from_integer_bits(bits: usize) -> Option { + match bits { + 16 => Some(ColumnType::SmallInt), + 32 => Some(ColumnType::Int), + 64 => Some(ColumnType::BigInt), + 128 => Some(ColumnType::Int128), + _ => None, + } + } + + /// Returns the larger integer type of two ColumnTypes if they are both integers. + /// + /// If either of the columns is not an integer, return None. + pub fn max_integer_type(&self, other: &Self) -> Option { + // If either of the columns is not an integer, return None + if !self.is_integer() || !other.is_integer() { + return None; + } + Self::from_integer_bits(std::cmp::max( + self.to_integer_bits().unwrap(), + other.to_integer_bits().unwrap(), + )) + } + /// Returns the precision of a ColumnType if it is converted to a decimal wrapped in Some(). If it can not be converted to a decimal, return None. pub fn precision_value(&self) -> Option { match self { diff --git a/crates/proof-of-sql/src/base/time/timestamp.rs b/crates/proof-of-sql/src/base/time/timestamp.rs index 3ae5d5b8c..38e3e459c 100644 --- a/crates/proof-of-sql/src/base/time/timestamp.rs +++ b/crates/proof-of-sql/src/base/time/timestamp.rs @@ -78,7 +78,7 @@ impl TryFrom<&str> for PoSQLTimeZone { /// Specifies different units of time measurement relative to the Unix epoch. It is essentially /// a wrapper over [arrow::datatypes::TimeUnit] so that we can derive Copy and implement custom traits /// such as bit distribution and Hash. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize, Hash)] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize, Hash, Ord, PartialOrd)] pub enum PoSQLTimeUnit { /// Represents a time unit of one second. Second, diff --git a/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs b/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs new file mode 100644 index 000000000..a8ea95f50 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs @@ -0,0 +1,102 @@ +use super::{ + scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns, + ProvableExpr, ProvableExprPlan, +}; +use crate::{ + base::{ + commitment::Commitment, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + proof::ProofError, + }, + sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder}, +}; +use bumpalo::Bump; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Provable numerical + / - expression +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct AddSubtractExpr { + lhs: Box>, + rhs: Box>, + is_subtract: bool, +} + +impl AddSubtractExpr { + /// Create numerical + / - expression + pub fn new( + lhs: Box>, + rhs: Box>, + is_subtract: bool, + ) -> Self { + Self { + lhs, + rhs, + is_subtract, + } + } +} + +impl ProvableExpr for AddSubtractExpr { + fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { + self.lhs.count(builder)?; + self.rhs.count(builder)?; + Ok(()) + } + + fn data_type(&self) -> ColumnType { + try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type()) + .expect("Failed to add/subtract column types") + } + + fn result_evaluate<'a>( + &self, + table_length: usize, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = + self.lhs.result_evaluate(table_length, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = + self.rhs.result_evaluate(table_length, alloc, accessor); + try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) + .expect("Failed to add/subtract columns") + } + + #[tracing::instrument( + name = "proofs.sql.ast.not_expr.prover_evaluate", + level = "info", + skip_all + )] + fn prover_evaluate<'a>( + &self, + builder: &mut ProofBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor); + try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) + .expect("Failed to add/subtract columns") + } + + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + ) -> Result { + let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?; + let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?; + let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); + let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); + let res = + scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract) + .expect("Failed to scale and add/subtract"); + Ok(res) + } + + fn get_column_references(&self, columns: &mut HashSet) { + self.lhs.get_column_references(columns); + self.rhs.get_column_references(columns); + } +} diff --git a/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs b/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs new file mode 100644 index 000000000..810b2fd51 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs @@ -0,0 +1,181 @@ +use crate::{ + base::{ + commitment::InnerProductProof, + database::{ + make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, + OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, + TestAccessor, + }, + }, + record_batch, + sql::ast::{ + test_expr::TestExprNode, + test_utility::{add, column, equal, subtract}, + ProvableExpr, ProvableExprPlan, + }, +}; +use arrow::record_batch::RecordBatch; +use bumpalo::Bump; +use curve25519_dalek::ristretto::RistrettoPoint; +use polars::prelude::*; +use rand::{rngs::StdRng, Rng}; +use rand_core::SeedableRng; + +// select results from table_ref where filter_col_l = filter_col_r0 + / - filter_col_r1 +#[allow(clippy::too_many_arguments)] +fn create_test_add_subtract_expr( + table_ref: &str, + results: &[&str], + filter_col_l: &str, + filter_col_r0: &str, + filter_col_r1: &str, + data: RecordBatch, + offset: usize, + is_subtract: bool, +) -> TestExprNode { + let mut accessor = RecordBatchTestAccessor::new_empty(); + let t = table_ref.parse().unwrap(); + accessor.add_table(t, data, offset); + let df_filter = if is_subtract { + polars::prelude::col(filter_col_l).eq(col(filter_col_r0) - col(filter_col_r1)) + } else { + polars::prelude::col(filter_col_l).eq(col(filter_col_r0) + col(filter_col_r1)) + }; + let filter_expr = equal( + column(t, filter_col_l, &accessor), + if is_subtract { + subtract( + column(t, filter_col_r0, &accessor), + column(t, filter_col_r1, &accessor), + ) + } else { + add( + column(t, filter_col_r0, &accessor), + column(t, filter_col_r1, &accessor), + ) + }, + ); + TestExprNode::new(t, results, filter_expr, df_filter, accessor) +} + +#[test] +fn we_can_prove_a_equals_add_query_with_a_single_selected_row() { + let data = record_batch!( + "a" => [123_i64, 456], + "b" => [4_i64, 1], + "c" => [123_i64, 457], + "d" => ["alfa", "gama"] + ); + let test_expr = + create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, false); + let res = test_expr.verify_expr(); + let expected_res = record_batch!( + "a" => [456_i64], + "d" => ["gama"] + ); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_prove_a_equals_subtract_query_with_a_single_selected_row() { + let data = record_batch!( + "a" => [127_i64, 458], + "b" => [4_i64, 1], + "c" => [123_i64, 457], + "d" => ["alfa", "gama"] + ); + let test_expr = + create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, true); + let res = test_expr.verify_expr(); + let expected_res = record_batch!( + "a" => [127_i64, 458], + "d" => ["alfa", "gama"] + ); + assert_eq!(res, expected_res); +} + +fn test_random_tables_with_given_offset(offset: usize) { + let descr = RandomTestAccessorDescriptor { + min_rows: 1, + max_rows: 20, + min_value: -3, + max_value: 3, + }; + let mut rng = StdRng::from_seed([0u8; 32]); + let cols = [ + ("l", ColumnType::BigInt), + ("r0", ColumnType::BigInt), + ("r1", ColumnType::BigInt), + ("varchar", ColumnType::VarChar), + ("integer", ColumnType::BigInt), + ]; + for _ in 0..20 { + let data = make_random_test_accessor_data(&mut rng, &cols, &descr); + let is_subtract = rng.gen::(); + let test_expr = create_test_add_subtract_expr( + "sxt.t", + &["l", "varchar", "integer"], + "l", + "r0", + "r1", + data, + offset, + is_subtract, + ); + let res = test_expr.verify_expr(); + let expected_res = test_expr.query_table(); + assert_eq!(res, expected_res); + } +} + +#[test] +fn we_can_query_random_tables_with_a_zero_offset() { + test_random_tables_with_given_offset(0); +} + +#[test] +fn we_can_query_random_tables_with_a_non_zero_offset() { + test_random_tables_with_given_offset(75); +} + +#[test] +fn we_can_compute_the_correct_output_of_an_add_expr_using_result_evaluate() { + let data = owned_table([ + bigint("a", [123, 456]), + bigint("b", [3, 1]), + bigint("c", [126, 453]), + varchar("d", ["alfa", "gama"]), + ]); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + let t = "sxt.t".parse().unwrap(); + accessor.add_table(t, data, 0); + let eq_expr: ProvableExprPlan = equal( + column(t, "c", &accessor), + add(column(t, "a", &accessor), column(t, "b", &accessor)), + ); + let alloc = Bump::new(); + let res = eq_expr.result_evaluate(2, &alloc, &accessor); + let expected_res = Column::Boolean(&[true, false]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_compute_the_correct_output_of_a_subtract_expr_using_result_evaluate() { + let data = owned_table([ + bigint("a", [123, 456]), + bigint("b", [3, 1]), + bigint("c", [126, 455]), + varchar("d", ["alfa", "gama"]), + ]); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + let t = "sxt.t".parse().unwrap(); + accessor.add_table(t, data, 0); + let eq_expr: ProvableExprPlan = equal( + column(t, "c", &accessor), + subtract(column(t, "a", &accessor), column(t, "b", &accessor)), + ); + let alloc = Bump::new(); + let res = eq_expr.result_evaluate(2, &alloc, &accessor); + let expected_res = Column::Boolean(&[false, true]); + assert_eq!(res, expected_res); +} diff --git a/crates/proof-of-sql/src/sql/ast/comparison_util.rs b/crates/proof-of-sql/src/sql/ast/comparison_util.rs index 382bd2c70..53689d5d6 100644 --- a/crates/proof-of-sql/src/sql/ast/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/ast/comparison_util.rs @@ -72,7 +72,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( ); // Check if the precision is valid let _max_precision = Precision::new(max_precision_value) - .map_err(|_| ConversionError::InvalidPrecision(max_precision_value))?; + .map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?; } unchecked_subtract_impl( alloc, diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index af0bfe89b..cce66e2d7 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -2,6 +2,11 @@ mod filter_result_expr; pub(crate) use filter_result_expr::FilterResultExpr; +mod add_subtract_expr; +pub(crate) use add_subtract_expr::AddSubtractExpr; +#[cfg(all(test, feature = "blitzar"))] +mod add_subtract_expr_test; + mod filter_expr; pub(crate) use filter_expr::FilterExpr; #[cfg(test)] @@ -52,6 +57,11 @@ mod not_expr_test; mod comparison_util; pub(crate) use comparison_util::{scale_and_subtract, scale_and_subtract_eval}; +mod numerical_util; +pub(crate) use numerical_util::{ + scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns, +}; + mod equals_expr; use equals_expr::*; #[cfg(all(test, feature = "blitzar"))] diff --git a/crates/proof-of-sql/src/sql/ast/numerical_util.rs b/crates/proof-of-sql/src/sql/ast/numerical_util.rs new file mode 100644 index 000000000..6ff4b89f5 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/numerical_util.rs @@ -0,0 +1,98 @@ +use crate::{ + base::{ + database::{Column, ColumnType}, + math::decimal::{scale_scalar, Precision}, + scalar::Scalar, + }, + sql::parse::{ConversionError, ConversionResult}, +}; +use bumpalo::Bump; + +// For decimal type manipulation please refer to +// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 + +/// Determine the output type of an add or subtract operation if it is possible +/// to add or subtract the two input types. If the types are not compatible, return +/// an error. +pub(crate) fn try_add_subtract_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ConversionResult { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ConversionError::DataTypeMismatch( + lhs.to_string(), + rhs.to_string(), + )); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = lhs.precision_value().unwrap_or(0) as i16; + let right_precision_value = rhs.precision_value().unwrap_or(0) as i16; + let left_scale = lhs.scale().unwrap_or(0); + let right_scale = rhs.scale().unwrap_or(0); + let scale = left_scale.max(right_scale); + let precision_value: i16 = scale as i16 + + (left_precision_value - left_scale as i16) + .max(right_precision_value - right_scale as i16) + + 1_i16; + let precision = u8::try_from(precision_value) + .map_err(|_| ConversionError::InvalidPrecision(precision_value)) + .and_then(|p| { + Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16)) + })?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + +/// Add or subtract two columns together. +/// +/// If the columns are not compatible for addition/subtraction, return an error. +pub(crate) fn try_add_subtract_columns<'a, S: Scalar>( + lhs: Column<'a, S>, + rhs: Column<'a, S>, + alloc: &'a Bump, + is_subtract: bool, +) -> ConversionResult> { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + if lhs_len != rhs_len { + return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len)); + } + let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); + let left_scale = lhs.column_type().scale().unwrap_or(0); + let right_scale = rhs.column_type().scale().unwrap_or(0); + let max_scale = left_scale.max(right_scale); + let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - left_scale); + let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - right_scale); + let res = alloc.alloc_slice_fill_with(lhs_len, |i| { + if is_subtract { + lhs_scalar[i] - rhs_scalar[i] + } else { + lhs_scalar[i] + rhs_scalar[i] + } + }); + Ok(Column::Scalar(res)) +} + +/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions. +pub(crate) fn scale_and_add_subtract_eval( + lhs_eval: S, + rhs_eval: S, + lhs_scale: i8, + rhs_scale: i8, + is_subtract: bool, +) -> ConversionResult { + let max_scale = lhs_scale.max(rhs_scale); + let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?; + let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?; + if is_subtract { + Ok(scaled_lhs_eval - scaled_rhs_eval) + } else { + Ok(scaled_lhs_eval + scaled_rhs_eval) + } +} diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 5ffaa6162..dcec2c374 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -1,5 +1,6 @@ use super::{ - AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, ProvableExpr, + AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, + ProvableExpr, }; use crate::{ base::{ @@ -34,6 +35,10 @@ pub enum ProvableExprPlan { Equals(EqualsExpr), /// Provable AST expression for an inequality expression Inequality(InequalityExpr), + /// Provable numeric + expression + Add(AddSubtractExpr), + /// Provable numeric - expression + Subtract(AddSubtractExpr), } impl ProvableExprPlan { /// Create column expression @@ -109,6 +114,48 @@ impl ProvableExprPlan { } } + /// Create a new add expression + pub fn try_new_add( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Add(AddSubtractExpr::new( + Box::new(lhs), + Box::new(rhs), + false, + ))) + } + } + + /// Create a new subtract expression + pub fn try_new_subtract( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Subtract(AddSubtractExpr::new( + Box::new(lhs), + Box::new(rhs), + true, + ))) + } + } + /// Check that the plan has the correct data type fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> { if self.data_type() == data_type { @@ -132,12 +179,16 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Literal(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Equals(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Inequality(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Add(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Subtract(expr) => ProvableExpr::::count(expr, builder), } } fn data_type(&self) -> ColumnType { match self { ProvableExprPlan::Column(expr) => expr.data_type(), + ProvableExprPlan::Add(expr) => expr.data_type(), + ProvableExprPlan::Subtract(expr) => expr.data_type(), ProvableExprPlan::Literal(expr) => ProvableExpr::::data_type(expr), ProvableExprPlan::And(_) | ProvableExprPlan::Or(_) @@ -175,6 +226,12 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) } + ProvableExprPlan::Add(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } } } @@ -206,6 +263,12 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) } + ProvableExprPlan::Add(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } } } @@ -224,6 +287,8 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Literal(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Equals(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Add(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Subtract(expr) => expr.verifier_evaluate(builder, accessor), } } @@ -244,6 +309,10 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::get_column_references(expr, columns) } + ProvableExprPlan::Add(expr) => ProvableExpr::::get_column_references(expr, columns), + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::get_column_references(expr, columns) + } } } } diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index 4ed303123..3bc668638 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -63,6 +63,20 @@ pub fn or( ProvableExprPlan::try_new_or(left, right).unwrap() } +pub fn add( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> ProvableExprPlan { + ProvableExprPlan::try_new_add(left, right).unwrap() +} + +pub fn subtract( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> ProvableExprPlan { + ProvableExprPlan::try_new_subtract(left, right).unwrap() +} + pub fn const_bool(val: bool) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::Boolean(val)) } diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 1ffc98fbf..255522b3e 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -59,8 +59,8 @@ pub enum ConversionError { PrecisionParseError(String), #[error("Decimal precision is not valid: {0}")] - /// Decimal precision exceeds the allowed limit - InvalidPrecision(u8), + /// Decimal precision is an integer but exceeds the allowed limit. We use i16 here to include all kinds of invalid precision values. + InvalidPrecision(i16), #[error("Encountered parsing error: {0}")] /// General parsing error diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index 0b08ad800..d33637533 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -73,7 +73,7 @@ impl ProvableExprPlanBuilder<'_> { Literal::Decimal(d) => { let scale = d.scale(); let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; + .map_err(|_| ConversionError::InvalidPrecision(d.precision() as i16))?; Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( precision, scale, @@ -130,13 +130,19 @@ impl ProvableExprPlanBuilder<'_> { let right = self.visit_expr(right); ProvableExprPlan::try_new_inequality(left?, right?, true) } - BinaryOperator::Add - | BinaryOperator::Subtract - | BinaryOperator::Multiply - | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( - "Binary operator {:?} is not supported yet", - op - ))), + BinaryOperator::Add => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_add(left?, right?) + } + BinaryOperator::Subtract => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_subtract(left?, right?) + } + BinaryOperator::Multiply | BinaryOperator::Division => Err( + ConversionError::Unprovable(format!("Binary operator {:?} is not supported", op)), + ), } } } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index e83615da4..770fe1b6c 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -276,6 +276,40 @@ fn we_can_convert_an_ast_with_two_columns() { assert_eq!(ast, expected_ast); } +#[test] +fn we_can_convert_an_ast_with_two_columns_and_arithmetic() { + let t = "sxt.sxt_tab".parse().unwrap(); + let accessor = record_batch_to_accessor( + t, + record_batch!( + "a" => Vec::::new(), + "b" => Vec::::new(), + "c" => Vec::::new(), + ), + 0_usize, + ); + let ast = query_to_provable_ast( + t, + "select a, b from sxt_tab where c = a + b - 1", + &accessor, + ); + let expected_ast = QueryExpr::new( + dense_filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + equal( + column(t, "c", &accessor), + subtract( + add(column(t, "a", &accessor), column(t, "b", &accessor)), + const_bigint(1), + ), + ), + ), + result(&[("a", "a"), ("b", "b")]), + ); + assert_eq!(ast, expected_ast); +} + #[test] fn we_can_parse_all_result_columns_with_select_star() { let t = "sxt.sxt_tab".parse().unwrap(); @@ -717,14 +751,17 @@ fn we_can_parse_order_by_with_multiple_columns() { ); let ast = query_to_provable_ast( t, - "select a, b from sxt_tab where a = 3 order by b desc, a asc", + "select a, b from sxt_tab where a = b + 3 order by b desc, a asc", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( cols_expr_plan(t, &["a", "b"], &accessor), tab(t), - equal(column(t, "a", &accessor), const_bigint(3)), + equal( + column(t, "a", &accessor), + add(column(t, "b", &accessor), const_bigint(3)), + ), ), composite_result(vec![ select(&[pc("a").alias("a"), pc("b").alias("b")]), @@ -1584,15 +1621,32 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr_plan(t, &["a", "b", "f", "h"], &accessor), + vec![ + ( + add(column(t, "a", &accessor), column(t, "b", &accessor)), + "__expr__".parse().unwrap(), + ), + ( + subtract(const_bigint(-77), column(t, "h", &accessor)), + "col".parse().unwrap(), + ), + ( + add(column(t, "a", &accessor), column(t, "f", &accessor)), + "af".parse().unwrap(), + ), + col_expr_plan(t, "a", &accessor), + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "f", &accessor), + col_expr_plan(t, "h", &accessor), + ], tab(t), const_bool(true), ), composite_result(vec![select(&[ - (pc("a") + pc("b")).alias("__expr__"), + pc("__expr__").alias("__expr__"), (lit_i64(2) * pc("f")).alias("f2"), - ((-77_i64).to_lit() - pc("h")).alias("col"), - (pc("a") + pc("f")).alias("af"), + pc("col").alias("col"), + pc("af").alias("af"), // TODO: add `a / b as a_div_b` result expr once polars properly // supports decimal division without panicking in production // (pc("a") / pc("b")).alias("a_div_b"), diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 1d56894b3..679568c30 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -12,8 +12,8 @@ use proof_of_sql::{ proof_primitive::dory::{DoryCommitment, DoryEvaluationProof, DoryProverPublicSetup}, record_batch, sql::{ - parse::{ConversionError, QueryExpr}, - proof::QueryProof, + parse::QueryExpr, + proof::{QueryError, QueryProof}, }, }; @@ -160,41 +160,68 @@ fn we_can_prove_a_basic_inequality_query_with_curve25519() { assert_eq!(owned_table_result, expected_result); } -//TODO: Once arithmetic is supported, this test should be updated to use arithmetic. #[test] #[cfg(feature = "blitzar")] -fn we_cannot_prove_a_query_with_arithmetic_in_where_clause_with_curve25519() { +fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_curve25519() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table( "sxt.table".parse().unwrap(), - owned_table([bigint("a", [1, 2, 3]), bigint("b", [1, 0, 2])]), + owned_table([bigint("a", [1, 2, 3]), bigint("b", [4, 1, 2])]), 0, ); - let res_query = QueryExpr::::try_new( + let query = QueryExpr::::try_new( "SELECT * FROM table WHERE b >= a + 1".parse().unwrap(), "sxt".parse().unwrap(), &accessor, - ); - assert!(matches!(res_query, Err(ConversionError::Unprovable(_)))); + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &()); + let owned_table_result = proof + .verify(query.proof_expr(), &accessor, &serialized_result, &()) + .unwrap() + .table; + let owned_table_result: OwnedTable = query + .result() + .transform_results(owned_table_result.try_into().unwrap()) + .unwrap() + .try_into() + .unwrap(); + let expected_result = owned_table([bigint("a", [1]), bigint("b", [4])]); + assert_eq!(owned_table_result, expected_result); } #[test] -fn we_cannot_prove_a_query_with_arithmetic_in_where_clause_with_dory() { +fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_dory() { let dory_prover_setup = DoryProverPublicSetup::rand(4, 3, &mut test_rng()); + let dory_verifier_setup = (&dory_prover_setup).into(); let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup( dory_prover_setup.clone(), ); accessor.add_table( "sxt.table".parse().unwrap(), - owned_table([bigint("a", [1, 2, 3]), bigint("b", [1, 0, 2])]), + owned_table([bigint("a", [1, -1, 3]), bigint("b", [0, 0, 2])]), 0, ); - let res_query = QueryExpr::::try_new( - "SELECT * FROM table WHERE b >= -(a)".parse().unwrap(), + let query = QueryExpr::::try_new( + "SELECT * FROM table WHERE b > 1 - a".parse().unwrap(), "sxt".parse().unwrap(), &accessor, - ); - assert!(matches!(res_query, Err(ConversionError::Unprovable(_)))); + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &dory_prover_setup); + let owned_table_result = proof + .verify( + query.proof_expr(), + &accessor, + &serialized_result, + &dory_verifier_setup, + ) + .unwrap() + .table; + let expected_result = owned_table([bigint("a", [3]), bigint("b", [2])]); + assert_eq!(owned_table_result, expected_result); } #[test] @@ -274,17 +301,17 @@ fn we_can_prove_a_complex_query_with_curve25519() { accessor.add_table( "sxt.table".parse().unwrap(), owned_table([ - bigint("a", [1, 2, 3]), - bigint("b", [1, 0, 1]), - bigint("c", [3, 3, -3]), - bigint("d", [1, 2, 3]), + smallint("a", [1_i16, 2, 3]), + int("b", [1_i32, 0, 1]), + bigint("c", [3_i64, 3, -3]), + bigint("d", [1_i64, 2, 3]), varchar("e", ["d", "e", "f"]), boolean("f", [true, false, false]), ]), 0, ); let query = QueryExpr::try_new( - "SELECT *, 45 as g, (a = b) or f as h FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" + "SELECT a + b + c + 1 as t, 45 as g, (a = b) or f as h FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" .parse() .unwrap(), "sxt".parse().unwrap(), @@ -297,16 +324,7 @@ fn we_can_prove_a_complex_query_with_curve25519() { .verify(query.proof_expr(), &accessor, &serialized_result, &()) .unwrap() .table; - let expected_result = owned_table([ - bigint("a", [3]), - bigint("b", [1]), - bigint("c", [-3]), - bigint("d", [3]), - varchar("e", ["f"]), - boolean("f", [false]), - bigint("g", [45]), - boolean("h", [false]), - ]); + let expected_result = owned_table([bigint("t", [2]), bigint("g", [45]), boolean("h", [false])]); assert_eq!(owned_table_result, expected_result); } @@ -331,7 +349,7 @@ fn we_can_prove_a_complex_query_with_dory() { 0, ); let query = QueryExpr::try_new( - "SELECT *, 32 as g, (c >= d) and f as h FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f" + "SELECT a - b + c - d as res, 32 as g, (c >= d) and f as h FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f" .parse() .unwrap(), "sxt".parse().unwrap(), @@ -349,16 +367,8 @@ fn we_can_prove_a_complex_query_with_dory() { ) .unwrap() .table; - let expected_result = owned_table([ - smallint("a", [1_i16]), - int("b", [1]), - bigint("c", [3]), - bigint("d", [1]), - varchar("e", ["d"]), - boolean("f", [true]), - bigint("g", [32]), - boolean("h", [true]), - ]); + let expected_result = + owned_table([bigint("res", [2]), bigint("g", [32]), boolean("h", [true])]); assert_eq!(owned_table_result, expected_result); } @@ -437,3 +447,54 @@ fn we_can_prove_a_basic_group_by_query_with_dory() { ]); assert_eq!(owned_table_result, expected_result); } + +// Overflow checks +#[test] +#[cfg(feature = "blitzar")] +fn we_can_prove_a_query_with_overflow_with_curve25519() { + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + accessor.add_table( + "sxt.table".parse().unwrap(), + owned_table([smallint("a", [i16::MAX]), smallint("b", [1])]), + 0, + ); + let query = QueryExpr::try_new( + "SELECT a + b as c from table".parse().unwrap(), + "sxt".parse().unwrap(), + &accessor, + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &()); + assert!(matches!( + proof.verify(query.proof_expr(), &accessor, &serialized_result, &()), + Err(QueryError::Overflow) + )); +} + +#[test] +fn we_can_prove_a_query_with_overflow_with_dory() { + let dory_prover_setup = DoryProverPublicSetup::rand(4, 3, &mut test_rng()); + let dory_verifier_setup = (&dory_prover_setup).into(); + + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup( + dory_prover_setup.clone(), + ); + accessor.add_table( + "sxt.table".parse().unwrap(), + owned_table([bigint("a", [i64::MIN]), smallint("b", [1])]), + 0, + ); + let query = QueryExpr::try_new( + "SELECT a - b as c from table".parse().unwrap(), + "sxt".parse().unwrap(), + &accessor, + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &dory_prover_setup); + assert!(matches!( + proof.verify(query.proof_expr(), &accessor, &serialized_result, &()), + Err(QueryError::Overflow) + )); +}