Skip to content

Commit

Permalink
feat: add AddSubExpr
Browse files Browse the repository at this point in the history
- add `AddSubExpr` and enable + and - elsewhere
  • Loading branch information
iajoiner committed Jun 18, 2024
1 parent 4e1d994 commit 738e62a
Show file tree
Hide file tree
Showing 13 changed files with 575 additions and 35 deletions.
8 changes: 4 additions & 4 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub const INT128_SCALE: usize = 0;
///
/// See `<https://ignite.apache.org/docs/latest/sql-reference/data-types>` for
/// a description of the native types used by Apache Ignite.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Deserialize, Copy)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Deserialize, Copy)]
pub enum ColumnType {
/// Mapped to bool
#[serde(alias = "BOOLEAN", alias = "boolean")]
Expand All @@ -188,12 +188,12 @@ pub enum ColumnType {
/// Mapped to String
#[serde(alias = "VARCHAR", alias = "varchar")]
VarChar,
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
/// Mapped to i256
#[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")]
Decimal75(Precision, i8),
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
}

impl ColumnType {
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/math/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal;
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Copy)]
/// limit-enforced precision
pub struct Precision(u8);
pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;
Expand Down
102 changes: 102 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::{
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns,
ProvableExpr, ProvableExprPlan,
};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical + / - expression
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AddSubtractExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
}

impl<C: Commitment> AddSubtractExpr<C> {
/// Create numerical + / - expression
pub fn new(
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
) -> Self {
Self {
lhs,
rhs,
is_subtract,
}
}
}

impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
Ok(())
}

fn data_type(&self) -> ColumnType {
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

#[tracing::instrument(
name = "proofs.sql.ast.not_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?;
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res =
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract)
.expect("Failed to scale and add/subtract");
Ok(res)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
181 changes: 181 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
use crate::{
base::{
commitment::InnerProductProof,
database::{
make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType,
OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor,
TestAccessor,
},
},
record_batch,
sql::ast::{
test_expr::TestExprNode,
test_utility::{add, column, equal, subtract},
ProvableExpr, ProvableExprPlan,
},
};
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use curve25519_dalek::ristretto::RistrettoPoint;
use polars::prelude::*;
use rand::{rngs::StdRng, Rng};
use rand_core::SeedableRng;

// select results from table_ref where filter_col_l = filter_col_r0 + / - filter_col_r1
#[allow(clippy::too_many_arguments)]
fn create_test_add_subtract_expr(
table_ref: &str,
results: &[&str],
filter_col_l: &str,
filter_col_r0: &str,
filter_col_r1: &str,
data: RecordBatch,
offset: usize,
is_subtract: bool,
) -> TestExprNode {
let mut accessor = RecordBatchTestAccessor::new_empty();
let t = table_ref.parse().unwrap();
accessor.add_table(t, data, offset);
let df_filter = if is_subtract {
polars::prelude::col(filter_col_l).eq(col(filter_col_r0) - col(filter_col_r1))
} else {
polars::prelude::col(filter_col_l).eq(col(filter_col_r0) + col(filter_col_r1))
};
let filter_expr = equal(
column(t, filter_col_l, &accessor),
if is_subtract {
subtract(
column(t, filter_col_r0, &accessor),
column(t, filter_col_r1, &accessor),
)
} else {
add(
column(t, filter_col_r0, &accessor),
column(t, filter_col_r1, &accessor),
)
},
);
TestExprNode::new(t, results, filter_expr, df_filter, accessor)
}

#[test]
fn we_can_prove_a_equals_add_query_with_a_single_selected_row() {
let data = record_batch!(
"a" => [123_i64, 456],
"b" => [4_i64, 1],
"c" => [123_i64, 457],
"d" => ["alfa", "gama"]
);
let test_expr =
create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, false);
let res = test_expr.verify_expr();
let expected_res = record_batch!(
"a" => [456_i64],
"d" => ["gama"]
);
assert_eq!(res, expected_res);
}

#[test]
fn we_can_prove_a_equals_subtract_query_with_a_single_selected_row() {
let data = record_batch!(
"a" => [127_i64, 458],
"b" => [4_i64, 1],
"c" => [123_i64, 457],
"d" => ["alfa", "gama"]
);
let test_expr =
create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, true);
let res = test_expr.verify_expr();
let expected_res = record_batch!(
"a" => [127_i64, 458],
"d" => ["alfa", "gama"]
);
assert_eq!(res, expected_res);
}

fn test_random_tables_with_given_offset(offset: usize) {
let descr = RandomTestAccessorDescriptor {
min_rows: 1,
max_rows: 20,
min_value: -3,
max_value: 3,
};
let mut rng = StdRng::from_seed([0u8; 32]);
let cols = [
("l", ColumnType::BigInt),
("r0", ColumnType::BigInt),
("r1", ColumnType::BigInt),
("varchar", ColumnType::VarChar),
("integer", ColumnType::BigInt),
];
for _ in 0..20 {
let data = make_random_test_accessor_data(&mut rng, &cols, &descr);
let is_subtract = rng.gen::<bool>();
let test_expr = create_test_add_subtract_expr(
"sxt.t",
&["l", "varchar", "integer"],
"l",
"r0",
"r1",
data,
offset,
is_subtract,
);
let res = test_expr.verify_expr();
let expected_res = test_expr.query_table();
assert_eq!(res, expected_res);
}
}

#[test]
fn we_can_query_random_tables_with_a_zero_offset() {
test_random_tables_with_given_offset(0);
}

#[test]
fn we_can_query_random_tables_with_a_non_zero_offset() {
test_random_tables_with_given_offset(75);
}

#[test]
fn we_can_compute_the_correct_output_of_an_add_expr_using_result_evaluate() {
let data = owned_table([
bigint("a", [123, 456]),
bigint("b", [3, 1]),
bigint("c", [126, 453]),
varchar("d", ["alfa", "gama"]),
]);
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = "sxt.t".parse().unwrap();
accessor.add_table(t, data, 0);
let eq_expr: ProvableExprPlan<RistrettoPoint> = equal(
column(t, "c", &accessor),
add(column(t, "a", &accessor), column(t, "b", &accessor)),
);
let alloc = Bump::new();
let res = eq_expr.result_evaluate(2, &alloc, &accessor);
let expected_res = Column::Boolean(&[true, false]);
assert_eq!(res, expected_res);
}

#[test]
fn we_can_compute_the_correct_output_of_a_subtract_expr_using_result_evaluate() {
let data = owned_table([
bigint("a", [123, 456]),
bigint("b", [3, 1]),
bigint("c", [126, 455]),
varchar("d", ["alfa", "gama"]),
]);
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = "sxt.t".parse().unwrap();
accessor.add_table(t, data, 0);
let eq_expr: ProvableExprPlan<RistrettoPoint> = equal(
column(t, "c", &accessor),
subtract(column(t, "a", &accessor), column(t, "b", &accessor)),
);
let alloc = Bump::new();
let res = eq_expr.result_evaluate(2, &alloc, &accessor);
let expected_res = Column::Boolean(&[false, true]);
assert_eq!(res, expected_res);
}
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/comparison_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
);
// Check if the precision is valid
let _max_precision = Precision::new(max_precision_value)
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value))?;
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?;
}
unchecked_subtract_impl(
alloc,
Expand Down
10 changes: 10 additions & 0 deletions crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
mod filter_result_expr;
pub(crate) use filter_result_expr::FilterResultExpr;

mod add_subtract_expr;
pub(crate) use add_subtract_expr::AddSubtractExpr;
#[cfg(all(test, feature = "blitzar"))]
mod add_subtract_expr_test;

mod filter_expr;
pub(crate) use filter_expr::FilterExpr;
#[cfg(test)]
Expand Down Expand Up @@ -52,6 +57,11 @@ mod not_expr_test;
mod comparison_util;
pub(crate) use comparison_util::{scale_and_subtract, scale_and_subtract_eval};

mod numerical_util;
pub(crate) use numerical_util::{
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns,
};

mod equals_expr;
use equals_expr::*;
#[cfg(all(test, feature = "blitzar"))]
Expand Down
Loading

0 comments on commit 738e62a

Please sign in to comment.