Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Big integer PoC #1304

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelogs/unreleased/1304-dark64
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add implementation of big integer arithmetic to stdlib
6 changes: 4 additions & 2 deletions zokrates_analysis/src/expression_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ExpressionValidator {
) -> Result<FieldElementExpression<'ast, T>, Self::Error> {
match e {
// these should have been propagated away
FieldElementExpression::And(_)
FieldElementExpression::IDiv(_)
| FieldElementExpression::Rem(_)
| FieldElementExpression::And(_)
| FieldElementExpression::Or(_)
| FieldElementExpression::Xor(_)
| FieldElementExpression::LeftShift(_)
| FieldElementExpression::RightShift(_) => Err(Error(format!(
"Found non-constant bitwise operation in field element expression `{}`",
"Field element expression `{}` must be a constant expression",
e
))),
FieldElementExpression::Pow(e) => {
Expand Down
6 changes: 6 additions & 0 deletions zokrates_analysis/src/flatten_complex_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,12 @@ fn fold_field_expression<'ast, T: Field>(
typed::FieldElementExpression::Div(e) => {
zir::FieldElementExpression::Div(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::IDiv(e) => {
zir::FieldElementExpression::IDiv(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Rem(e) => {
zir::FieldElementExpression::Rem(f.fold_binary_expression(statements_buffer, e))
}
typed::FieldElementExpression::Pow(e) => {
zir::FieldElementExpression::Pow(f.fold_binary_expression(statements_buffer, e))
}
Expand Down
87 changes: 87 additions & 0 deletions zokrates_analysis/src/propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,22 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
Ok(UExpression::and(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Or(e) => match (
self.fold_uint_expression(*e.left)?.into_inner(),
self.fold_uint_expression(*e.right)?.into_inner(),
) {
(UExpressionInner::Value(v1), UExpressionInner::Value(v2)) => {
Ok(UExpression::value(v1.value | v2.value))
}
(UExpressionInner::Value(v), e) | (e, UExpressionInner::Value(v))
if v.value == 0 =>
{
Ok(e)
}
(e1, e2) => {
Ok(UExpression::or(e1.annotate(bitwidth), e2.annotate(bitwidth)).into_inner())
}
},
UExpressionInner::Not(e) => {
let e = self.fold_uint_expression(*e.inner)?.into_inner();
match e {
Expand Down Expand Up @@ -939,6 +955,35 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> {
(e1, e2) => Ok(e1 / e2),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => FieldElementExpression::idiv(e1, e2),
})
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

Ok(match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
FieldElementExpression::value(T::zero())
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
)
}
(e1, e2) => e1 % e2,
})
}
FieldElementExpression::Neg(e) => match self.fold_field_expression(*e.inner)? {
FieldElementExpression::Value(n) => {
Ok(FieldElementExpression::value(T::zero() - n.value))
Expand Down Expand Up @@ -1606,6 +1651,48 @@ mod tests {
);
}

#[test]
fn idiv() {
let e = FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
);

assert_eq!(
Propagator::default().fold_field_expression(e),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);
}

#[test]
fn rem() {
let mut propagator = Propagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);
}

#[test]
fn pow() {
let e = FieldElementExpression::pow(
Expand Down
102 changes: 102 additions & 0 deletions zokrates_analysis/src/zir_propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,42 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ZirPropagator<'ast, T> {
(e1, e2) => Ok(FieldElementExpression::div(e1, e2).span(e.span)),
}
}
FieldElementExpression::IDiv(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(e, FieldElementExpression::Value(n)) if n.value == T::from(1) => Ok(e),
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().div(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::idiv(e1, e2).span(e.span)),
}
}
FieldElementExpression::Rem(e) => {
let left = self.fold_field_expression(*e.left)?;
let right = self.fold_field_expression(*e.right)?;

match (left, right) {
(_, FieldElementExpression::Value(n)) if n.value == T::from(0) => {
Err(Error::DivisionByZero)
}
(_, FieldElementExpression::Value(n)) if n.value == T::from(1) => {
Ok(FieldElementExpression::value(T::zero()))
}
(FieldElementExpression::Value(n1), FieldElementExpression::Value(n2)) => {
Ok(FieldElementExpression::value(
T::try_from(n1.value.to_biguint().rem(n2.value.to_biguint())).unwrap(),
))
}
(e1, e2) => Ok(FieldElementExpression::rem(e1, e2).span(e.span)),
}
}
FieldElementExpression::Pow(e) => {
let exponent = self.fold_uint_expression(*e.right)?;
match (self.fold_field_expression(*e.left)?, exponent.into_inner()) {
Expand Down Expand Up @@ -1077,6 +1113,72 @@ mod tests {
);
}

#[test]
fn idiv() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::value(Bn128Field::from(7)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(3)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::identifier("a".into()))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::idiv(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn rem() {
let mut propagator = ZirPropagator::default();

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(5)),
FieldElementExpression::value(Bn128Field::from(2)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(1)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::value(Bn128Field::from(2)),
FieldElementExpression::value(Bn128Field::from(5)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(2)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::rem(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(1)),
)),
Ok(FieldElementExpression::value(Bn128Field::from(0)))
);

assert_eq!(
propagator.fold_field_expression(FieldElementExpression::div(
FieldElementExpression::identifier("a".into()),
FieldElementExpression::value(Bn128Field::from(0)),
)),
Err(Error::DivisionByZero)
);
}

#[test]
fn pow() {
let mut propagator = ZirPropagator::<Bn128Field>::default();
Expand Down
7 changes: 7 additions & 0 deletions zokrates_ast/src/common/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ impl OperatorStr for OpDiv {
const STR: &'static str = "/";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpIDiv;

impl OperatorStr for OpIDiv {
const STR: &'static str = "\\";
}

#[derive(Clone, PartialEq, Debug, Hash, Eq, PartialOrd, Ord)]
pub struct OpRem;

Expand Down
6 changes: 5 additions & 1 deletion zokrates_ast/src/ir/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ir::Parameter;
use crate::ir::ProgIterator;
use crate::ir::Statement;
use crate::ir::Variable;
use crate::Solver;
use std::collections::HashSet;
use zokrates_field::Field;

Expand Down Expand Up @@ -46,7 +47,10 @@ impl<'ast, T: Field> Folder<'ast, T> for UnconstrainedVariableDetector {
&mut self,
d: DirectiveStatement<'ast, T>,
) -> Vec<Statement<'ast, T>> {
self.variables.extend(d.outputs.iter());
match d.solver {
Solver::Zir(_) => {} // we do not check variables introduced by assembly
_ => self.variables.extend(d.outputs.iter()), // this is not necessary, but we keep it as a sanity check
};
vec![Statement::Directive(d)]
}
}
4 changes: 2 additions & 2 deletions zokrates_ast/src/ir/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ impl<'ast, T: Field, I: IntoIterator<Item = Statement<'ast, T>>> ProgIterator<'a
if matches!(s, Statement::Constraint(..)) {
count += 1;
}
let s: Vec<Statement<T>> = solver_indexer
let s: Vec<Statement<T>> = unconstrained_variable_detector
.fold_statement(s)
.into_iter()
.flat_map(|s| unconstrained_variable_detector.fold_statement(s))
.flat_map(|s| solver_indexer.fold_statement(s))
.collect();
for s in s {
serde_cbor::to_writer(&mut w, &s)?;
Expand Down
8 changes: 8 additions & 0 deletions zokrates_ast/src/typed/folder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,14 @@ pub fn fold_field_expression_cases<'ast, T: Field, F: Folder<'ast, T>>(
BinaryOrExpression::Binary(e) => Div(e),
BinaryOrExpression::Expression(u) => u,
},
IDiv(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => IDiv(e),
BinaryOrExpression::Expression(u) => u,
},
Rem(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => Rem(e),
BinaryOrExpression::Expression(u) => u,
},
Pow(e) => match f.fold_binary_expression(&Type::FieldElement, e) {
BinaryOrExpression::Binary(e) => Pow(e),
BinaryOrExpression::Expression(u) => u,
Expand Down
13 changes: 13 additions & 0 deletions zokrates_ast/src/typed/integer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@ pub enum IntExpression<'ast, T> {
IntExpression<'ast, T>,
>,
),
IDiv(
BinaryExpression<
OpIDiv,
IntExpression<'ast, T>,
IntExpression<'ast, T>,
IntExpression<'ast, T>,
>,
),
Rem(
BinaryExpression<
OpRem,
Expand Down Expand Up @@ -434,6 +442,10 @@ impl<'ast, T> Neg for IntExpression<'ast, T> {
}

impl<'ast, T> IntExpression<'ast, T> {
pub fn idiv(self, other: Self) -> Self {
IntExpression::IDiv(BinaryExpression::new(self, other))
}

pub fn pow(self, other: Self) -> Self {
IntExpression::Pow(BinaryExpression::new(self, other))
}
Expand Down Expand Up @@ -470,6 +482,7 @@ impl<'ast, T: fmt::Display> fmt::Display for IntExpression<'ast, T> {
IntExpression::Pos(ref e) => write!(f, "{}", e),
IntExpression::Neg(ref e) => write!(f, "{}", e),
IntExpression::Div(ref e) => write!(f, "{}", e),
IntExpression::IDiv(ref e) => write!(f, "{}", e),
IntExpression::Rem(ref e) => write!(f, "{}", e),
IntExpression::Pow(ref e) => write!(f, "{}", e),
IntExpression::Select(ref select) => write!(f, "{}", select),
Expand Down
Loading