From c77e7ae5c3caf5dd99a7d0eca4284407c1993767 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Feb 2023 20:05:30 +0100 Subject: [PATCH 01/19] show help when running "zokrates mpc" --- changelogs/unreleased/1275-dark64 | 1 + zokrates_cli/src/ops/mpc/mod.rs | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 changelogs/unreleased/1275-dark64 diff --git a/changelogs/unreleased/1275-dark64 b/changelogs/unreleased/1275-dark64 new file mode 100644 index 000000000..506dce652 --- /dev/null +++ b/changelogs/unreleased/1275-dark64 @@ -0,0 +1 @@ +Show help when running `zokrates mpc` \ No newline at end of file diff --git a/zokrates_cli/src/ops/mpc/mod.rs b/zokrates_cli/src/ops/mpc/mod.rs index 000d82023..dd5913ed1 100644 --- a/zokrates_cli/src/ops/mpc/mod.rs +++ b/zokrates_cli/src/ops/mpc/mod.rs @@ -1,4 +1,4 @@ -use clap::{App, ArgMatches, SubCommand}; +use clap::{App, ArgMatches, SubCommand, AppSettings}; pub mod beacon; pub mod contribute; @@ -9,6 +9,7 @@ pub mod verify; pub fn subcommand() -> App<'static, 'static> { SubCommand::with_name("mpc") .about("Multi-party computation (MPC) protocol") + .setting(AppSettings::SubcommandRequiredElseHelp) .subcommands(vec![ init::subcommand().display_order(1), contribute::subcommand().display_order(2), @@ -25,6 +26,6 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { ("beacon", Some(sub_matches)) => beacon::exec(sub_matches), ("verify", Some(sub_matches)) => verify::exec(sub_matches), ("export", Some(sub_matches)) => export::exec(sub_matches), - _ => unreachable!(), + _ => unreachable!() } } From aca048eda55ec5a8637f09bff3305c75f397a126 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 6 Feb 2023 20:07:58 +0100 Subject: [PATCH 02/19] fmt --- zokrates_cli/src/ops/mpc/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_cli/src/ops/mpc/mod.rs b/zokrates_cli/src/ops/mpc/mod.rs index dd5913ed1..b01a1ddc9 100644 --- a/zokrates_cli/src/ops/mpc/mod.rs +++ b/zokrates_cli/src/ops/mpc/mod.rs @@ -1,4 +1,4 @@ -use clap::{App, ArgMatches, SubCommand, AppSettings}; +use clap::{App, AppSettings, ArgMatches, SubCommand}; pub mod beacon; pub mod contribute; @@ -26,6 +26,6 @@ pub fn exec(sub_matches: &ArgMatches) -> Result<(), String> { ("beacon", Some(sub_matches)) => beacon::exec(sub_matches), ("verify", Some(sub_matches)) => verify::exec(sub_matches), ("export", Some(sub_matches)) => export::exec(sub_matches), - _ => unreachable!() + _ => unreachable!(), } } From 86d50e79d8cf253f0c4f35744a4fca023b1d1f75 Mon Sep 17 00:00:00 2001 From: dark64 Date: Wed, 8 Feb 2023 19:24:50 +0100 Subject: [PATCH 03/19] fix setup ffi in zokrates-js --- zokrates_js/index.js | 16 +++++++++++-- zokrates_js/package-lock.json | 4 ++-- zokrates_js/src/lib.rs | 42 ++++++++++++++++++++++++++++++----- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/zokrates_js/index.js b/zokrates_js/index.js index a6d689f8e..55a7dbda6 100644 --- a/zokrates_js/index.js +++ b/zokrates_js/index.js @@ -68,13 +68,25 @@ const initialize = async () => { return result; }, setup: (program, entropy, options) => { - return wasmExports.setup(program, entropy, options); + const ptr = wasmExports.setup(program, entropy, options); + const result = { + vk: ptr.vk(), + pk: ptr.pk(), + }; + ptr.free(); + return result; }, universalSetup: (curve, size, entropy) => { return wasmExports.universal_setup(curve, size, entropy); }, setupWithSrs: (srs, program, options) => { - return wasmExports.setup_with_srs(srs, program, options); + const ptr = wasmExports.setup_with_srs(srs, program, options); + const result = { + vk: ptr.vk(), + pk: ptr.pk(), + }; + ptr.free(); + return result; }, generateProof: (program, witness, provingKey, entropy, options) => { return wasmExports.generate_proof( diff --git a/zokrates_js/package-lock.json b/zokrates_js/package-lock.json index 00ecd023a..0821df1a2 100644 --- a/zokrates_js/package-lock.json +++ b/zokrates_js/package-lock.json @@ -1,12 +1,12 @@ { "name": "zokrates-js", - "version": "1.1.4", + "version": "1.1.5", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "zokrates-js", - "version": "1.1.4", + "version": "1.1.5", "license": "GPLv3", "dependencies": { "pako": "^2.1.0" diff --git a/zokrates_js/src/lib.rs b/zokrates_js/src/lib.rs index 69513ea00..2b4dccebc 100644 --- a/zokrates_js/src/lib.rs +++ b/zokrates_js/src/lib.rs @@ -53,6 +53,7 @@ impl CompilationResult { arr.copy_from(&self.program); arr } + pub fn abi(&self) -> JsValue { JsValue::from_serde(&self.abi).unwrap() } @@ -88,9 +89,11 @@ impl ComputationResult { pub fn witness(&self) -> JsValue { JsValue::from_str(&self.witness) } + pub fn output(&self) -> JsValue { JsValue::from_str(&self.output) } + pub fn snarkjs_witness(&self) -> Option { self.snarkjs_witness.as_ref().map(|w| { let arr = js_sys::Uint8Array::new_with_length(w.len() as u32); @@ -100,6 +103,25 @@ impl ComputationResult { } } +#[wasm_bindgen] +pub struct Keypair { + vk: JsValue, + pk: Vec, +} + +#[wasm_bindgen] +impl Keypair { + pub fn vk(&self) -> JsValue { + self.vk.to_owned() + } + + pub fn pk(&self) -> js_sys::Uint8Array { + let arr = js_sys::Uint8Array::new_with_length(self.pk.len() as u32); + arr.copy_from(&self.pk); + arr + } +} + pub struct JsResolver<'a> { callback: &'a js_sys::Function, } @@ -204,6 +226,7 @@ impl<'a> Write for LogWriter<'a> { fn write(&mut self, buf: &[u8]) -> std::io::Result { self.buf.write(buf) } + fn flush(&mut self) -> std::io::Result<()> { self.callback .call1( @@ -352,10 +375,13 @@ mod internal { >( program: ir::Prog, rng: &mut R, - ) -> JsValue { + ) -> Keypair { let keypair = B::setup(program, rng); let tagged_keypair = TaggedKeypair::::new(keypair); - JsValue::from_serde(&tagged_keypair).unwrap() + Keypair { + vk: JsValue::from_serde(&tagged_keypair.vk).unwrap(), + pk: tagged_keypair.pk, + } } pub fn setup_universal< @@ -367,9 +393,13 @@ mod internal { >( srs: &[u8], program: ir::ProgIterator<'a, T, I>, - ) -> Result { + ) -> Result { let keypair = B::setup(srs.to_vec(), program).map_err(|e| JsValue::from_str(&e))?; - Ok(JsValue::from_serde(&TaggedKeypair::::new(keypair)).unwrap()) + let tagged_keypair = TaggedKeypair::::new(keypair); + Ok(Keypair { + vk: JsValue::from_serde(&tagged_keypair.vk).unwrap(), + pk: tagged_keypair.pk, + }) } pub fn universal_setup_of_size< @@ -528,7 +558,7 @@ pub fn export_solidity_verifier(vk: JsValue) -> Result { } #[wasm_bindgen] -pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result { +pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result { let options: serde_json::Value = options.into_serde().unwrap(); let backend = BackendParameter::try_from( @@ -597,7 +627,7 @@ pub fn setup(program: &[u8], entropy: JsValue, options: JsValue) -> Result Result { +pub fn setup_with_srs(srs: &[u8], program: &[u8], options: JsValue) -> Result { let options: serde_json::Value = options.into_serde().unwrap(); let scheme = SchemeParameter::try_from( From 5ecfe5ea0dd35845a34c900132098afcc3e03576 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 14 Feb 2023 15:36:33 +0100 Subject: [PATCH 04/19] pass ark and bellman features to zokrates_analysis --- zokrates_interpreter/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 77aaff662..41cdfa353 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -5,8 +5,8 @@ edition = "2021" [features] default = ["bellman", "ark"] -bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman", "zokrates_ast/bellman"] -ark = ["ark-bls12-377", "zokrates_embed/ark", "zokrates_ast/ark"] +bellman = ["zokrates_field/bellman", "pairing_ce", "zokrates_embed/bellman", "zokrates_ast/bellman", "zokrates_analysis/bellman"] +ark = ["ark-bls12-377", "zokrates_embed/ark", "zokrates_ast/ark", "zokrates_analysis/ark"] [dependencies] zokrates_field = { version = "0.5", path = "../zokrates_field", default-features = false } From fb45889147a0c90c77eca83c0f4a469e75ee0a5d Mon Sep 17 00:00:00 2001 From: Darko Macesic Date: Wed, 15 Feb 2023 19:36:18 +0100 Subject: [PATCH 05/19] add changelog --- changelogs/unreleased/1277-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1277-dark64 diff --git a/changelogs/unreleased/1277-dark64 b/changelogs/unreleased/1277-dark64 new file mode 100644 index 000000000..e94ff6a55 --- /dev/null +++ b/changelogs/unreleased/1277-dark64 @@ -0,0 +1 @@ +Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair From a2b335cf8e2d20a3f283cee11ebf8c711253ddbf Mon Sep 17 00:00:00 2001 From: schaeff Date: Sun, 19 Feb 2023 20:45:59 +0100 Subject: [PATCH 06/19] wip --- zokrates_analysis/src/propagation.rs | 13 +- .../src/reducer/constants_reader.rs | 45 +- .../src/reducer/constants_writer.rs | 6 +- zokrates_analysis/src/reducer/inline.rs | 105 +-- zokrates_analysis/src/reducer/mod.rs | 656 ++++++++++-------- zokrates_analysis/src/reducer/shallow_ssa.rs | 82 +-- zokrates_ast/src/typed/folder.rs | 13 +- zokrates_ast/src/typed/identifier.rs | 59 +- zokrates_ast/src/typed/mod.rs | 42 ++ zokrates_ast/src/typed/result_folder.rs | 13 +- zokrates_ast/src/typed/types.rs | 9 +- zokrates_core/src/imports.rs | 1 + zokrates_core/src/semantics.rs | 2 +- 13 files changed, 623 insertions(+), 423 deletions(-) diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7e5c0a17..b7701dc01 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -45,7 +45,7 @@ impl fmt::Display for Error { } #[derive(Debug)] -pub struct Propagator<'ast, 'a, T: Field> { +pub struct Propagator<'ast, 'a, T> { // constants keeps track of constant expressions // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is constants: &'a mut Constants<'ast, T>, @@ -317,12 +317,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { } }; + // particular case of `lhs = rhs` + if TypedExpression::from(assignee.clone()) == expr { + return Ok(vec![]); + } + if expr.is_constant() { match assignee { TypedAssignee::Identifier(var) => { let expr = expr.into_canonical_constant(); - assert!(self.constants.insert(var.id, expr).is_none()); + assert!( + self.constants.insert(var.clone().id, expr).is_none(), + "{}", + var + ); Ok(vec![]) } diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index 4ee0d1359..f0ec252e8 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -2,10 +2,11 @@ use crate::reducer::ConstantDefinitions; use zokrates_ast::typed::{ - folder::*, ArrayExpression, ArrayExpressionInner, ArrayType, BooleanExpression, CoreIdentifier, - DeclarationConstant, Expr, FieldElementExpression, Id, Identifier, IdentifierExpression, - StructExpression, StructExpressionInner, StructType, TupleExpression, TupleExpressionInner, - TupleType, TypedProgram, TypedSymbolDeclaration, UBitwidth, UExpression, UExpressionInner, + folder::*, identifier::FrameIdentifier, ArrayExpression, ArrayExpressionInner, ArrayType, + BooleanExpression, CoreIdentifier, DeclarationConstant, Expr, FieldElementExpression, Id, + Identifier, IdentifierExpression, StructExpression, StructExpressionInner, StructType, + TupleExpression, TupleExpressionInner, TupleType, TypedProgram, TypedSymbolDeclaration, + UBitwidth, UExpression, UExpressionInner, }; use zokrates_field::Field; @@ -61,7 +62,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { FieldElementExpression::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. @@ -86,7 +91,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { BooleanExpression::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. @@ -112,7 +121,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { UExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. @@ -136,7 +149,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { ArrayExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. @@ -162,7 +179,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { TupleExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. @@ -188,7 +209,11 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { StructExpressionInner::Identifier(IdentifierExpression { id: Identifier { - id: CoreIdentifier::Constant(c), + id: + FrameIdentifier { + id: CoreIdentifier::Constant(c), + frame, + }, version, }, .. diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index d4e03d3d4..d38daf150 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -118,11 +118,7 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for ConstantsWriter<'ast, T> { signature: DeclarationSignature::new().output(c.ty.clone()), }; - let mut inlined_wrapper = reduce_function( - wrapper, - ConcreteGenericsAssignment::default(), - &self.program, - )?; + let mut inlined_wrapper = reduce_function(wrapper, &self.program)?; if let TypedStatement::Return(expression) = inlined_wrapper.statements.pop().unwrap() diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 31f237e82..f1a7229e6 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -26,7 +26,6 @@ // - The body of the function is in SSA form // - The return value(s) are assigned to internal variables -use crate::reducer::Output; use crate::reducer::ShallowTransformer; use crate::reducer::Versions; @@ -34,6 +33,8 @@ use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; use zokrates_ast::typed::CoreIdentifier; use zokrates_ast::typed::Identifier; +use zokrates_ast::typed::TypedAssignee; +use zokrates_ast::typed::UBitwidth; use zokrates_ast::typed::{ ConcreteFunctionKey, ConcreteSignature, ConcreteVariable, DeclarationFunctionKey, Expr, Signature, Type, TypedExpression, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, @@ -58,7 +59,7 @@ pub enum InlineError<'ast, T> { } fn get_canonical_function<'ast, T: Field>( - function_key: DeclarationFunctionKey<'ast, T>, + function_key: &DeclarationFunctionKey<'ast, T>, program: &TypedProgram<'ast, T>, ) -> TypedFunctionSymbolDeclaration<'ast, T> { let s = program @@ -66,27 +67,32 @@ fn get_canonical_function<'ast, T: Field>( .get(&function_key.module) .unwrap() .functions_iter() - .find(|d| d.key == function_key) + .find(|d| d.key == *function_key) .unwrap(); match &s.symbol { - TypedFunctionSymbol::There(key) => get_canonical_function(key.clone(), program), + TypedFunctionSymbol::There(key) => get_canonical_function(key, program), _ => s.clone(), } } type InlineResult<'ast, T> = Result< - Output<(Vec>, TypedExpression<'ast, T>), Vec>>, + ( + Vec>, + Vec>, + Vec>, + Vec>, + TypedExpression<'ast, T>, + ), InlineError<'ast, T>, >; pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( - k: DeclarationFunctionKey<'ast, T>, + k: &DeclarationFunctionKey<'ast, T>, generics: Vec>>, arguments: Vec>, output: &E::Ty, program: &TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, ) -> InlineResult<'ast, T> { use zokrates_ast::typed::Typed; let output_type = output.clone().into_type(); @@ -124,7 +130,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( Ok(s) => s, Err(_) => { return Err(InlineError::NonConstant( - k, + k.clone(), generics, arguments, output_type, @@ -132,7 +138,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( } }; - let decl = get_canonical_function(k.clone(), program); + let decl = get_canonical_function(&k, program); // get an assignment of generics for this call site let assignment: ConcreteGenericsAssignment<'ast> = k @@ -162,23 +168,35 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); - let (ssa_f, incomplete_data) = match ShallowTransformer::transform(f, &assignment, versions) { - Output::Complete(v) => (v, None), - Output::Incomplete(statements, for_loop_versions) => (statements, Some(for_loop_versions)), - }; + // let ssa_f = ShallowTransformer::transform(f, &assignment, versions); - let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone()); + // let ssa_f = f; + + // let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone()); + + let generics_bindings: Vec<_> = assignment + .0 + .into_iter() + .map(|(identifier, value)| { + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::uint( + CoreIdentifier::from(identifier), + UBitwidth::B32, + )), + TypedExpression::from(UExpression::from(value)).into(), + ) + }) + .collect(); - let input_bindings: Vec> = ssa_f + let input_variables: Vec> = f .arguments .into_iter() .zip(inferred_signature.inputs.clone()) .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) - .zip(arguments.clone()) - .map(|(v, a)| TypedStatement::definition(Variable::from(v).into(), a)) + .map(|v| Variable::from(v)) .collect(); - let (statements, mut returns): (Vec<_>, Vec<_>) = ssa_f + let (statements, mut returns): (Vec<_>, Vec<_>) = f .statements .into_iter() .partition(|s| !matches!(s, TypedStatement::Return(..))); @@ -190,31 +208,28 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( _ => unreachable!(), }; - let v: ConcreteVariable<'ast> = ConcreteVariable::new( - Identifier::from(CoreIdentifier::Call(0)).version( - *versions - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ), - *inferred_signature.output.clone(), - false, - ); - - let expression = TypedExpression::from(Variable::from(v.clone())); - - let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression); - - let pop_log = TypedStatement::PopCallLog; - - let statements: Vec<_> = std::iter::once(call_log) - .chain(input_bindings) - .chain(statements) - .chain(std::iter::once(output_binding)) - .chain(std::iter::once(pop_log)) - .collect(); - - Ok(incomplete_data - .map(|d| Output::Incomplete((statements.clone(), expression.clone()), d)) - .unwrap_or_else(|| Output::Complete((statements, expression)))) + // let v: ConcreteVariable<'ast> = ConcreteVariable::new( + // Identifier::from(CoreIdentifier::Call(0)).version( + // *versions + // .entry(CoreIdentifier::Call(0)) + // .and_modify(|e| *e += 1) // if it was already declared, we increment + // .or_insert(0), + // ), + // *inferred_signature.output.clone(), + // false, + // ); + + // let expression = TypedExpression::from(Variable::from(v.clone())); + + // let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression); + + // let pop_log = TypedStatement::PopCallLog; + + Ok(( + input_variables, + arguments, + generics_bindings, + statements, + return_expression, + )) } diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index ea6feb536..ceae5e9ea 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -18,9 +18,11 @@ mod shallow_ssa; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; +use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::types::ConcreteGenericsAssignment; use zokrates_ast::typed::types::GGenericsAssignment; +use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; @@ -31,6 +33,7 @@ use zokrates_ast::typed::{ TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, }; +use zokrates_ast::zir::result_folder::fold_assembly_statement; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; @@ -47,13 +50,21 @@ pub type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; // An SSA version map, giving access to the latest version number for each identifier -pub type Versions<'ast> = HashMap, usize>; +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Versions<'ast> { + map: HashMap, usize>>, + frame: usize, +} -// A container to represent whether more treatment must be applied to the function -#[derive(Debug, PartialEq, Eq)] -pub enum Output { - Complete(U), - Incomplete(U, V), +impl<'ast> Versions<'ast> { + fn insert_in_frame( + &mut self, + id: CoreIdentifier<'ast>, + version: usize, + frame: usize, + ) -> Option { + self.map.entry(frame).or_default().insert(id, version) + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -84,125 +95,31 @@ impl fmt::Display for Error { } } -#[derive(Debug, Default)] -struct Substitutions<'ast>(HashMap, HashMap>); - -impl<'ast> Substitutions<'ast> { - // create an equivalent substitution map where all paths - // are of length 1 - fn canonicalize(self) -> Self { - Substitutions( - self.0 - .into_iter() - .map(|(id, sub)| (id, Self::canonicalize_sub(sub))) - .collect(), - ) - } - - // canonicalize substitutions for a given id - fn canonicalize_sub(sub: HashMap) -> HashMap { - fn add_to_cache( - sub: &HashMap, - cache: HashMap, - k: usize, - ) -> HashMap { - match cache.contains_key(&k) { - // `k` is already in the cache, no changes to the cache - true => cache, - _ => match sub.get(&k) { - // `k` does not point to anything, no changes to the cache - None => cache, - // `k` points to some `v - Some(v) => { - // add `v` to the cache - let cache = add_to_cache(sub, cache, *v); - // `k` points to what `v` points to, or to `v` - let v = cache.get(v).cloned().unwrap_or(*v); - let mut cache = cache; - cache.insert(k, v); - cache - } - }, - } - } - - sub.keys() - .fold(HashMap::new(), |cache, k| add_to_cache(&sub, cache, *k)) - } -} - -struct Sub<'a, 'ast> { - substitutions: &'a Substitutions<'ast>, -} - -impl<'a, 'ast> Sub<'a, 'ast> { - fn new(substitutions: &'a Substitutions<'ast>) -> Self { - Self { substitutions } - } -} - -impl<'a, 'ast, T: Field> Folder<'ast, T> for Sub<'a, 'ast> { - fn fold_name(&mut self, id: Identifier<'ast>) -> Identifier<'ast> { - let version = self - .substitutions - .0 - .get(&id.id) - .map(|sub| sub.get(&id.version).cloned().unwrap_or(id.version)) - .unwrap_or(id.version); - id.version(version) - } -} - -fn register<'ast>( - substitutions: &mut Substitutions<'ast>, - substitute: &Versions<'ast>, - with: &Versions<'ast>, -) { - for (id, key, value) in substitute - .iter() - .filter_map(|(id, version)| with.get(id).map(|to| (id, version, to))) - .filter(|(_, key, value)| key != value) - { - let sub = substitutions.0.entry(id.clone()).or_default(); - - // redirect `k` to `v`, unless `v` is already redirected to `v0`, in which case we redirect to `v0` - - sub.insert(*key, *sub.get(value).unwrap_or(value)); - } -} - #[derive(Debug)] struct Reducer<'ast, 'a, T> { + propagator: Propagator<'ast, 'a, T>, statement_buffer: Vec>, - for_loop_versions: Vec>, - for_loop_versions_after: Vec>, - program: &'a TypedProgram<'ast, T>, + latest_frame: usize, versions: &'a mut Versions<'ast>, - substitutions: &'a mut Substitutions<'ast>, - complete: bool, + program: &'a TypedProgram<'ast, T>, } impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { fn new( program: &'a TypedProgram<'ast, T>, versions: &'a mut Versions<'ast>, - substitutions: &'a mut Substitutions<'ast>, - for_loop_versions: Vec>, + constants: &'a mut Constants<'ast, T>, ) -> Self { - // we reverse the vector as it's cheaper to `pop` than to take from - // the head - let mut for_loop_versions = for_loop_versions; - - for_loop_versions.reverse(); + // println!("create reducer with"); + // println!("{} versions", versions.len()); + // println!("{} constants", constants.len()); Reducer { + propagator: Propagator::with_constants(constants), statement_buffer: vec![], - for_loop_versions_after: vec![], - for_loop_versions, - substitutions, + latest_frame: 0, program, versions, - complete: true, } } } @@ -210,6 +127,15 @@ impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { type Error = Error; + fn fold_parameter( + &mut self, + p: DeclarationParameter<'ast, T>, + ) -> Result, Self::Error> { + let id = p.id.id.id.id.clone(); + assert!(self.versions.insert_in_frame(id, 0, 0).is_none()); + Ok(p) + } + fn fold_function_call_expression< E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, >( @@ -229,46 +155,104 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map(|e| self.fold_expression(e)) .collect::>()?; - let res = inline_call::<_, E>( - e.function_key.clone(), - generics, - arguments, - ty, - self.program, - self.versions, - ); + // back up the current frame + let frame_backup = self.versions.frame; + + // create a new frame + self.latest_frame += 1; + + // point the versions to this frame + self.versions.frame = self.latest_frame; + self.versions + .map + .insert(self.versions.frame, Default::default()); + + // println!("FRAME READY TO INLINE {}", self.versions.frame); + + // println!("GENERICS {:?}", generics); + + let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); match res { - Ok(Output::Complete((statements, expression))) => { - self.complete &= true; + Ok((input_variables, arguments, generics_bindings, statements, expression)) => { + // println!("FRAME BEFORE REDUCING INLINE RESULT {}", self.versions.frame); + + let mut transformer = ShallowTransformer::with_versions(self.versions); + let propagator = &mut self.propagator; + + // println!("BINDINGS BEFORE {}", generics_bindings[0]); + + let generics_bindings: Vec<_> = generics_bindings + .into_iter() + .flat_map(|s| transformer.fold_statement(s)) + .flat_map(|s| propagator.fold_statement(s).unwrap()) + .collect(); + + // println!("{:#?}", propagator); + + self.statement_buffer.extend(generics_bindings); + + // the lhs is from the inner call frame, the rhs is from the outer one, so only fld the lhs + let input_bindings: Vec<_> = input_variables + .into_iter() + .zip(arguments) + .map(|(v, a)| { + TypedStatement::definition(transformer.fold_assignee(v.into()), a) + }) + .collect(); + + let input_bindings: Vec<_> = input_bindings + .into_iter() + .flat_map(|s| propagator.fold_statement(s).unwrap()) + .collect(); + + self.statement_buffer.extend(input_bindings); + + let statements: Vec<_> = statements + .into_iter() + .flat_map(|s| self.fold_statement(s).unwrap()) + .collect(); + + // println!("FRAME READY TO SSA {}", self.versions.frame); + + let mut transformer = ShallowTransformer::with_versions(self.versions); + let propagator = &mut self.propagator; + self.statement_buffer.extend(statements); + + // println!("call result {}", expression); + + let expression = transformer.fold_expression(expression); + + let expression = propagator.fold_expression(expression).unwrap(); + + // println!("call result reduced {}", expression); + + // clean versions + self.versions.map.remove(&self.versions.frame); + + // restore the original frame + // println!("RESTORING BACKUP {}", frame_backup); + self.versions.frame = frame_backup; + Ok(FunctionCallOrExpression::Expression( E::from(expression).into_inner(), )) } - Ok(Output::Incomplete((statements, expression), delta_for_loop_versions)) => { - self.complete = false; - self.statement_buffer.extend(statements); - self.for_loop_versions_after.extend(delta_for_loop_versions); - Ok(FunctionCallOrExpression::Expression( - E::from(expression.clone()).into_inner(), - )) - } Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => { - self.complete = false; - - Ok(FunctionCallOrExpression::Expression(E::function_call( - key, generics, arguments, - ))) - } + Err(InlineError::NonConstant(key, generics, arguments, _)) => Ok( + FunctionCallOrExpression::Expression(E::function_call(key, generics, arguments)), + ), Err(InlineError::Flat(embed, generics, arguments, output_type)) => { let identifier = Identifier::from(CoreIdentifier::Call(0)).version( *self .versions + .map + .entry(self.versions.frame) + .or_default() .entry(CoreIdentifier::Call(0)) .and_modify(|e| *e += 1) // if it was already declared, we increment .or_insert(0), @@ -293,23 +277,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, b: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { - // backup the statements and continue with a fresh state - let statement_buffer = std::mem::take(&mut self.statement_buffer); - - let block = fold_block_expression(self, b)?; - - // put the original statements back and extract the statements created by visiting the block - let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); - - // return the visited block, augmented with the statements created while visiting it - Ok(BlockExpression { - statements: block - .statements - .into_iter() - .chain(extra_statements) - .collect(), - ..block - }) + // // backup the statements and continue with a fresh state + // let statement_buffer = std::mem::take(&mut self.statement_buffer); + + // let block = fold_block_expression(self, b)?; + + // // put the original statements back and extract the statements created by visiting the block + // let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); + + // // return the visited block, augmented with the statements created while visiting it + // Ok(BlockExpression { + // statements: block + // .statements + // .into_iter() + // .chain(extra_statements) + // .collect(), + // ..block + // }) + todo!() } fn fold_canonical_constant_identifier( @@ -324,77 +309,133 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { let res = match s { - TypedStatement::For(v, from, to, statements) => { - let versions_before = self.for_loop_versions.pop().unwrap(); + TypedStatement::Definition(a, rhs) => { + let mut transformer = ShallowTransformer::with_versions(self.versions); + // println!("rhs {}", rhs); + let rhs = transformer.fold_definition_rhs(rhs); + // println!("ssa-ed {}", rhs); + let rhs = self.propagator.fold_definition_rhs(rhs).unwrap(); + // println!("propagated {}", rhs); + let rhs = self.fold_definition_rhs(rhs).unwrap(); + // println!("reduced {}", rhs); - match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => { - let mut out_statements = vec![]; + // println!("ASSIGNEE {}", a); - // get a fresh set of versions for all variables to use as a starting point inside the loop - self.versions.values_mut().for_each(|v| *v += 1); + let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a); - // add this set of versions to the substitution, pointing to the versions before the loop - register(self.substitutions, self.versions, &versions_before); + // println!("SSA ASSIGNEE {}", a); - // the versions after the loop are found by applying an offset of 1 to the versions before the loop - let versions_after = versions_before - .clone() - .into_iter() - .map(|(k, v)| (k, v + 1)) - .collect(); - - let mut transformer = ShallowTransformer::with_versions(self.versions); - - if to - from > MAX_FOR_LOOP_SIZE { - return Err(Error::LoopTooLarge(to.saturating_sub(*from))); - } - - for index in *from..*to { - let statements: Vec> = - std::iter::once(TypedStatement::definition( - v.clone().into(), - UExpression::from(index as u32).into(), - )) - .chain(statements.clone().into_iter()) - .flat_map(|s| transformer.fold_statement(s)) - .collect(); - - out_statements.extend(statements); - } - - let backups = transformer.for_loop_backups; - let blocked = transformer.blocked; - - // we know the final versions of the variables after full unrolling of the loop - // the versions after the loop need to point to these, so we add to the substitutions - register(self.substitutions, &versions_after, self.versions); - - // we may have found new for loops when unrolling this one, which means new backed up versions - // we insert these in our backup list and update our cursor - - self.for_loop_versions_after.extend(backups); - - // if the ssa transform got blocked, the reduction is not complete - self.complete &= !blocked; - - Ok(out_statements) - } - _ => { - let from = self.fold_uint_expression(from)?; - let to = self.fold_uint_expression(to)?; - self.complete = false; - self.for_loop_versions_after.push(versions_before); - Ok(vec![TypedStatement::For(v, from, to, statements)]) - } + let s = self + .propagator + .fold_statement(TypedStatement::Definition(a, rhs)) + .unwrap(); + + self.statement_buffer.drain(..).chain(s).collect::>() + } + TypedStatement::For(v, from, to, statements) => { + let mut transformer = ShallowTransformer::with_versions(self.versions); + let from = transformer.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from).unwrap(); + let to = transformer.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to).unwrap(); + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to) + .flat_map(|index| { + std::iter::once(TypedStatement::definition( + v.clone().into(), + UExpression::from(index as u32).into(), + )) + .chain(statements.clone()) + .flat_map(|s| self.fold_statement(s).unwrap()) + .collect::>() + }) + .collect(), + _ => unimplemented!(), } } - s => fold_statement(self, s), + TypedStatement::Assembly(_) => todo!(), + s => { + let mut transformer = ShallowTransformer::with_versions(self.versions); + let propagator = &mut self.propagator; + transformer + .fold_statement(s) + .into_iter() + // .inspect(|s| println!("ssa {}\n", s)) + .flat_map(|s| propagator.fold_statement(s).unwrap()) + // .inspect(|s| println!("propagated {}\n", s)) + .collect() + } }; - - res.map(|res| self.statement_buffer.drain(..).chain(res).collect()) + Ok(res) } + // fn fold_statement( + // &mut self, + // s: TypedStatement<'ast, T>, + // ) -> Result>, Self::Error> { + // let mut transformer = ShallowTransformer::with_versions(self.versions); + // let propagator = &mut self.propagator; + + // println!("FOLD_STATEMENT: {}", s); + + // let s: Vec<_> = transformer + // .fold_statement(s) + // .into_iter() + // // .inspect(|s| println!("ssa {}\n", s)) + // .flat_map(|s| propagator.fold_statement(s).unwrap()) + // // .inspect(|s| println!("propagated {}\n", s)) + // .collect(); + + // for s in &s { + // println!("OUTER: {}", s); + // } + + // let res: Vec<_> = s + // .into_iter() + // .flat_map(|s| match s { + // TypedStatement::For(v, from, to, statements) => { + // match (from.as_inner(), to.as_inner()) { + // (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from + // ..*to) + // .flat_map(|index| { + // std::iter::once(TypedStatement::definition( + // v.clone().into(), + // UExpression::from(index as u32).into(), + // )) + // .chain(statements.clone()) + // .flat_map(|s| self.fold_statement(s).unwrap()) + // .collect::>() + // }) + // .collect(), + // _ => unimplemented!(), + // } + // } + // s => { + // println!("UNROLL/INLINE STATEMENT {}", s); + + // let s = fold_statement(self, s).unwrap(); + + // for s in &self.statement_buffer { + // println!("BUFFER {}", s); + // } + + // for s in &s { + // println!("RESULT {}", s); + // } + + // self.statement_buffer.drain(..).chain(s).collect::>() + // } + // }) + // .collect(); + + // for s in &res { + // // println!("DONE: {}", s); + // } + + // Ok(res) + // } + fn fold_array_expression_inner( &mut self, array_ty: &ArrayType<'ast, T>, @@ -402,19 +443,20 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { match e { ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = self.fold_array_expression(array)?; - let from = self.fold_uint_expression(from)?; - let to = self.fold_uint_expression(to)?; - - match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { - Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - } - _ => { - self.complete = false; - Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - } - } + // let array = self.fold_array_expression(array)?; + // let from = self.fold_uint_expression(from)?; + // let to = self.fold_uint_expression(to)?; + + // match (from.as_inner(), to.as_inner()) { + // (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { + // Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + // } + // _ => { + // self.complete = false; + // Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + // } + // } + todo!() } _ => fold_array_expression_inner(self, array_ty, e), } @@ -443,7 +485,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E match main_function.signature.generics.len() { 0 => { - let main_function = reduce_function(main_function, GGenericsAssignment::default(), &p)?; + let main_function = reduce_function(main_function, &p)?; Ok(TypedProgram { main: p.main.clone(), @@ -467,83 +509,129 @@ pub fn reduce_program(p: TypedProgram) -> Result, E fn reduce_function<'ast, T: Field>( f: TypedFunction<'ast, T>, - generics: ConcreteGenericsAssignment<'ast>, program: &TypedProgram<'ast, T>, ) -> Result, Error> { let mut versions = Versions::default(); - let mut constants = Constants::default(); - let f = match ShallowTransformer::transform(f, &generics, &mut versions) { - Output::Complete(f) => Ok(f), - Output::Incomplete(new_f, new_for_loop_versions) => { - let mut for_loop_versions = new_for_loop_versions; + assert!(f.signature.generics.is_empty()); - let mut f = new_f; + // let f = match ShallowTransformer::transform(f, &generics, &mut versions) { + // Output::Complete(f) => Ok(f), + // Output::Incomplete(new_f, new_for_loop_versions) => { + // let mut for_loop_versions = new_for_loop_versions; - let mut substitutions = Substitutions::default(); + // let mut f = Propagator::with_constants(&mut constants) + // .fold_function(new_f) + // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - let mut hash = None; + // let mut substitutions = Substitutions::default(); - loop { - let mut reducer = Reducer::new( - program, - &mut versions, - &mut substitutions, - for_loop_versions, - ); + // let mut hash = None; - let new_f = TypedFunction { - statements: f - .statements - .into_iter() - .map(|s| reducer.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), - ..f - }; + // let mut len = f.statements.len(); - assert!(reducer.for_loop_versions.is_empty()); + // // println!("{}", f); - match reducer.complete { - true => { - substitutions = substitutions.canonicalize(); + // loop { + // let mut reducer = Reducer::new( + // program, + // &mut versions, + // &mut substitutions, + // for_loop_versions, + // &mut constants, + // ); - let new_f = Sub::new(&substitutions).fold_function(new_f); + // println!("reduce"); - let new_f = Propagator::with_constants(&mut constants) - .fold_function(new_f) - .map_err(|e| Error::Incompatible(format!("{}", e)))?; + // let new_f = TypedFunction { + // statements: f + // .statements + // .into_iter() + // .map(|s| reducer.fold_statement(s)) + // .collect::, _>>()? + // .into_iter() + // .flatten() + // .collect(), + // ..f + // }; - break Ok(new_f); - } - false => { - for_loop_versions = reducer.for_loop_versions_after; + // println!("done"); - let new_f = Sub::new(&substitutions).fold_function(new_f); + // // println!("after reduction {}", new_f); - f = Propagator::with_constants(&mut constants) - .fold_function(new_f) - .map_err(|e| Error::Incompatible(format!("{}", e)))?; + // println!( + // "count {}, unrolled {} loops", + // new_f.statements.len(), + // reducer.credits + // ); - let new_hash = Some(compute_hash(&f)); + // assert!(reducer.for_loop_versions.is_empty()); - if new_hash == hash { - break Err(Error::NoProgress); - } else { - hash = new_hash - } - } - } - } - } - }?; + // match reducer.complete { + // true => { + // substitutions = substitutions.canonicalize(); + + // let new_f = Sub::new(&substitutions).fold_function(new_f); + + // // println!("after last sub {}", new_f); + + // // let new_f = Propagator::with_constants(&mut constants) + // // .fold_function(new_f) + // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; + + // // println!("after last prop {}", new_f); + + // break Ok(new_f); + // } + // false => { + // for_loop_versions = reducer.for_loop_versions_after; + + // println!("canonicalize"); + + // // substitutions = substitutions.canonicalize(); + + // // let new_f = Sub::new(&substitutions).fold_function(new_f); + + // println!("done"); + // // println!("after sub {}", new_f); + + // println!("propagate"); + + // // f = Propagator::with_constants(&mut constants) + // // .fold_function(new_f) + // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; + + // println!("done"); + + // f = new_f; + + // // println!("after prop {}", f); + + // let new_len = f.statements.len(); + + // if new_len == len { + // let new_hash = Some(compute_hash(&f)); + + // if new_hash == hash { + // break Err(Error::NoProgress); + // } else { + // hash = new_hash; + // } + // } else { + // len = new_len; + // } + // } + // } + // } + // } + // }?; + + // Propagator::with_constants(&mut constants) + // .fold_function(f) + // .map_err(|e| Error::Incompatible(format!("{}", e))) - Propagator::with_constants(&mut constants) - .fold_function(f) - .map_err(|e| Error::Incompatible(format!("{}", e))) + Reducer::new(program, &mut versions, &mut constants).fold_function(f) } fn compute_hash(f: &TypedFunction) -> u64 { diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index a071a0446..aa4ab6a2e 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -27,54 +27,41 @@ // } use zokrates_ast::typed::folder::*; +use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::types::ConcreteGenericsAssignment; use zokrates_ast::typed::types::Type; use zokrates_ast::typed::*; use zokrates_field::Field; -use super::{Output, Versions}; +use super::Versions; pub struct ShallowTransformer<'ast, 'a> { // version index for any variable name pub versions: &'a mut Versions<'ast>, - // A backup of the versions before each for-loop - pub for_loop_backups: Vec>, - // whether all statements could be unrolled so far. Loops with variable bounds cannot. - pub blocked: bool, } impl<'ast, 'a> ShallowTransformer<'ast, 'a> { pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { - ShallowTransformer { - versions, - for_loop_backups: Vec::default(), - blocked: false, - } - } - - // increase all versions by 1 and return the old versions - fn create_version_gap(&mut self) -> Versions<'ast> { - let ret = self.versions.clone(); - self.versions.values_mut().for_each(|v| *v += 1); - ret + ShallowTransformer { versions } } fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let version = *self - .versions + let frame_versions = self.versions.map.entry(self.versions.frame).or_default(); + + let version = frame_versions .entry(c_id.clone()) .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0); // otherwise, we start from this version + .or_default(); // otherwise, we start from this version - Identifier::from(c_id).version(version) + Identifier::from(c_id).version(*version) } fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { assert_eq!(v.id.version, 0); Variable { - id: self.issue_next_identifier(v.id.id), + id: self.issue_next_identifier(v.id.id.id), ..v } } @@ -83,15 +70,10 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { f: TypedFunction<'ast, T>, generics: &ConcreteGenericsAssignment<'ast>, versions: &'a mut Versions<'ast>, - ) -> Output, Vec>> { + ) -> TypedFunction<'ast, T> { let mut unroller = ShallowTransformer::with_versions(versions); - let f = unroller.fold_function(f, generics); - - match unroller.blocked { - false => Output::Complete(f), - true => Output::Incomplete(f, unroller.for_loop_backups), - } + unroller.fold_function(f, generics) } fn fold_function( @@ -116,13 +98,15 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { .collect(); for arg in &f.arguments { - let _ = self.issue_next_identifier(arg.id.id.id.clone()); + let _ = self.issue_next_identifier(arg.id.id.id.id.clone()); } fold_function(self, f) } +} - fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { +impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => { let v = self.issue_next_ssa_variable(v); @@ -131,9 +115,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { a => fold_assignee(self, a), } } -} -impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -162,9 +144,6 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { TypedStatement::For(v, from, to, stats) => { let from = self.fold_uint_expression(from); let to = self.fold_uint_expression(to); - self.blocked = true; - let versions_before_loop = self.create_version_gap(); - self.for_loop_backups.push(versions_before_loop); vec![TypedStatement::For(v, from, to, stats)] } s => fold_statement(self, s), @@ -172,25 +151,22 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - let res = Identifier { - version: *self.versions.get(&(n.id)).unwrap_or(&0), - ..n + let version = self + .versions + .map + .get(&self.versions.frame) + .unwrap() + .get(&n.id.id) + .cloned() + .unwrap_or(0); + + let id = FrameIdentifier { + frame: self.versions.frame, + ..n.id }; - res - } - fn fold_function_call_expression< - E: Id<'ast, T> + From> + Expr<'ast, T> + FunctionCall<'ast, T>, - >( - &mut self, - ty: &E::Ty, - c: FunctionCallExpression<'ast, T, E>, - ) -> FunctionCallOrExpression<'ast, T, E> { - if !c.function_key.id.starts_with('_') { - self.blocked = true; - } - - fold_function_call_expression(self, ty, c) + let res = Identifier { version, id }; + res } } diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index d3e87fcd0..70200ecd3 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -4,6 +4,8 @@ use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; +use super::identifier::FrameIdentifier; + pub trait Fold<'ast, T: Field>: Sized { fn fold>(self, f: &mut F) -> Self; } @@ -128,11 +130,12 @@ pub trait Folder<'ast, T: Field>: Sized { } fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { - let id = match n.id { - CoreIdentifier::Constant(c) => { - CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)) - } - id => id, + let id = match n.id.id.clone() { + CoreIdentifier::Constant(c) => FrameIdentifier { + id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)), + frame: 0, + }, + id => n.id, }; Identifier { id, ..n } diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index 772eb2bf2..91aaaa308 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -24,18 +24,34 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { } } -impl<'ast> From> for CoreIdentifier<'ast> { - fn from(s: CanonicalConstantIdentifier<'ast>) -> CoreIdentifier<'ast> { - CoreIdentifier::Constant(s) +impl<'ast> CoreIdentifier<'ast> { + pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { + FrameIdentifier { id: self, frame } } } +impl<'ast> From> for FrameIdentifier<'ast> { + fn from(s: CanonicalConstantIdentifier<'ast>) -> FrameIdentifier<'ast> { + FrameIdentifier::from(CoreIdentifier::Constant(s)) + } +} + +/// A identifier for a variable in a given call frame +#[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct FrameIdentifier<'ast> { + /// the id of the variable + #[serde(borrow)] + pub id: CoreIdentifier<'ast>, + /// the frame of the variable + pub frame: usize, +} + /// A identifier for a variable #[derive(Debug, PartialEq, Clone, Hash, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Identifier<'ast> { /// the id of the variable #[serde(borrow)] - pub id: CoreIdentifier<'ast>, + pub id: FrameIdentifier<'ast>, /// the version of the variable, used after SSA transformation pub version: usize, } @@ -58,7 +74,7 @@ impl<'ast> fmt::Display for ShadowedIdentifier<'ast> { if self.shadow == 0 { write!(f, "{}", self.id) } else { - write!(f, "{}_{}", self.id, self.shadow) + write!(f, "{}_s{}", self.id, self.shadow) } } } @@ -68,20 +84,45 @@ impl<'ast> fmt::Display for Identifier<'ast> { if self.version == 0 { write!(f, "{}", self.id) } else { - write!(f, "{}_{}", self.id, self.version) + write!(f, "{}_v{}", self.id, self.version) + } + } +} + +impl<'ast> fmt::Display for FrameIdentifier<'ast> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.frame == 0 { + write!(f, "{}", self.id) + } else { + write!(f, "{}_f{}", self.id, self.frame) } } } impl<'ast> From> for Identifier<'ast> { fn from(id: CanonicalConstantIdentifier<'ast>) -> Identifier<'ast> { - Identifier::from(CoreIdentifier::Constant(id)) + Identifier::from(FrameIdentifier::from(CoreIdentifier::Constant(id))) + } +} + +impl<'ast> From> for Identifier<'ast> { + fn from(id: FrameIdentifier<'ast>) -> Identifier<'ast> { + Identifier { id, version: 0 } } } impl<'ast> From> for Identifier<'ast> { fn from(id: CoreIdentifier<'ast>) -> Identifier<'ast> { - Identifier { id, version: 0 } + Identifier { + id: FrameIdentifier::from(id), + version: 0, + } + } +} + +impl<'ast> From> for FrameIdentifier<'ast> { + fn from(id: CoreIdentifier<'ast>) -> FrameIdentifier<'ast> { + FrameIdentifier { id, frame: 0 } } } @@ -107,6 +148,6 @@ impl<'ast> From<&'ast str> for CoreIdentifier<'ast> { impl<'ast> From<&'ast str> for Identifier<'ast> { fn from(id: &'ast str) -> Identifier<'ast> { - Identifier::from(CoreIdentifier::from(id)) + Identifier::from(FrameIdentifier::from(CoreIdentifier::from(id))) } } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 83ada241e..df8b0f638 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -1303,6 +1303,48 @@ impl<'ast, T: Field> From> for FieldElementExpression<'as } } +impl<'ast, T: Field> From> for BooleanExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => BooleanExpression::identifier(v.id), + TypedAssignee::Element(box a, index) => BooleanExpression::element(a.into(), index), + TypedAssignee::Member(box a, id) => BooleanExpression::member(a.into(), id), + TypedAssignee::Select(box a, box index) => BooleanExpression::select(a.into(), index), + } + } +} + +impl<'ast, T: Field> From> for UExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee { + TypedAssignee::Identifier(v) => { + let inner = UExpression::identifier(v.id); + match v._type { + GType::Uint(bitwidth) => inner.annotate(bitwidth), + _ => unreachable!(), + } + } + TypedAssignee::Element(box a, index) => UExpression::element(a.into(), index), + TypedAssignee::Member(box a, id) => UExpression::member(a.into(), id), + TypedAssignee::Select(box a, box index) => UExpression::select(a.into(), index), + } + } +} + +impl<'ast, T: Field> From> for TypedExpression<'ast, T> { + fn from(assignee: TypedAssignee<'ast, T>) -> Self { + match assignee.get_type() { + Type::FieldElement => FieldElementExpression::from(assignee).into(), + Type::Boolean => BooleanExpression::from(assignee).into(), + Type::Struct(_) => StructExpression::from(assignee).into(), + Type::Array(_) => ArrayExpression::from(assignee).into(), + Type::Uint(_) => UExpression::from(assignee).into(), + Type::Tuple(_) => TupleExpression::from(assignee).into(), + Type::Int => unreachable!(), + } + } +} + impl<'ast, T> Add for FieldElementExpression<'ast, T> { type Output = Self; diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index e4146c504..d0458a3c8 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -4,6 +4,8 @@ use crate::typed::types::*; use crate::typed::*; use zokrates_field::Field; +use super::identifier::FrameIdentifier; + pub trait ResultFold<'ast, T: Field>: Sized { fn fold>(self, f: &mut F) -> Result; } @@ -156,11 +158,12 @@ pub trait ResultFolder<'ast, T: Field>: Sized { } fn fold_name(&mut self, n: Identifier<'ast>) -> Result, Self::Error> { - let id = match n.id { - CoreIdentifier::Constant(c) => { - CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?) - } - id => id, + let id = match n.id.id.clone() { + CoreIdentifier::Constant(c) => FrameIdentifier { + id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?), + frame: 0, + }, + id => n.id, }; Ok(Identifier { id, ..n }) diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index a453fef3f..61f49eabf 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -241,13 +241,14 @@ impl<'ast, T: Field> From> for UExpression<'ast, T> fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { DeclarationConstant::Generic(g) => { - UExpression::identifier(CoreIdentifier::from(g).into()).annotate(UBitwidth::B32) + // UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32) + unreachable!() } DeclarationConstant::Concrete(v) => { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) } DeclarationConstant::Constant(v) => { - UExpression::identifier(CoreIdentifier::from(v).into()).annotate(UBitwidth::B32) + UExpression::identifier(FrameIdentifier::from(v).into()).annotate(UBitwidth::B32) } DeclarationConstant::Expression(e) => e.try_into().unwrap(), } @@ -1144,8 +1145,7 @@ pub fn check_type<'ast, T, S: Clone + PartialEq + PartialEq>( impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: CanonicalConstantIdentifier<'ast>) -> Self { - UExpression::identifier(Identifier::from(CoreIdentifier::Constant(c))) - .annotate(UBitwidth::B32) + UExpression::identifier(Identifier::from(FrameIdentifier::from(c))).annotate(UBitwidth::B32) } } @@ -1230,6 +1230,7 @@ pub use self::signature::{ try_from_g_signature, ConcreteSignature, DeclarationSignature, GSignature, Signature, }; +use super::identifier::FrameIdentifier; use super::{Id, ShadowedIdentifier}; pub mod signature { diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index f7fa95f13..d384aa646 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -79,6 +79,7 @@ impl Importer { .into_iter() .map(|s| match s.value.symbol { Symbol::Here(SymbolDefinition::Import(import)) => { + log::debug!("Resolve {} from {}", import, location.display()); Importer::resolve::(import, &location, resolver, modules, arena) } _ => Ok(s), diff --git a/zokrates_core/src/semantics.rs b/zokrates_core/src/semantics.rs index a4f7bb0c3..4e006eed3 100644 --- a/zokrates_core/src/semantics.rs +++ b/zokrates_core/src/semantics.rs @@ -1170,7 +1170,7 @@ impl<'ast, T: Field> Checker<'ast, T> { let id = arg.id.value.id; let info = IdentifierInfo { - id: decl_v.id.id.clone(), + id: decl_v.id.id.id.clone(), ty, is_mutable, }; From 01450c741fd598156894c1f3f8cc892cc1918d0e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 20 Feb 2023 03:39:58 +0100 Subject: [PATCH 07/19] fix more tests --- zokrates_analysis/src/reducer/mod.rs | 187 ++++++++++++------ zokrates_analysis/src/reducer/shallow_ssa.rs | 6 +- zokrates_ast/src/typed/result_folder.rs | 2 +- .../tests/tests/uint/rotate.zok | 3 +- 4 files changed, 131 insertions(+), 67 deletions(-) diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index ceae5e9ea..bd129890c 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -24,6 +24,8 @@ use zokrates_ast::typed::types::ConcreteGenericsAssignment; use zokrates_ast::typed::types::GGenericsAssignment; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; +use zokrates_ast::typed::Typed; +use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; use zokrates_ast::typed::{ @@ -146,13 +148,27 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let generics = e .generics .into_iter() - .map(|g| g.map(|g| self.fold_uint_expression(g)).transpose()) + .map(|g| { + g.map(|g| { + let g = + ShallowTransformer::with_versions(self.versions).fold_uint_expression(g); + let g = self.propagator.fold_uint_expression(g).unwrap(); + + self.fold_uint_expression(g) + }) + .transpose() + }) .collect::>()?; let arguments = e .arguments .into_iter() - .map(|e| self.fold_expression(e)) + .map(|e| { + let e = ShallowTransformer::with_versions(self.versions).fold_expression(e); + let e = self.propagator.fold_expression(e).unwrap(); + + self.fold_expression(e) + }) .collect::>()?; // back up the current frame @@ -167,25 +183,21 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .map .insert(self.versions.frame, Default::default()); - // println!("FRAME READY TO INLINE {}", self.versions.frame); - // println!("GENERICS {:?}", generics); let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); - match res { + let res = match res { Ok((input_variables, arguments, generics_bindings, statements, expression)) => { - // println!("FRAME BEFORE REDUCING INLINE RESULT {}", self.versions.frame); - - let mut transformer = ShallowTransformer::with_versions(self.versions); - let propagator = &mut self.propagator; - - // println!("BINDINGS BEFORE {}", generics_bindings[0]); - let generics_bindings: Vec<_> = generics_bindings .into_iter() - .flat_map(|s| transformer.fold_statement(s)) - .flat_map(|s| propagator.fold_statement(s).unwrap()) + .flat_map(|s| { + ShallowTransformer::with_versions(self.versions).fold_statement(s) + }) + .flat_map(|s| self.propagator.fold_statement(s).unwrap()) + .collect::>() + .into_iter() + .flat_map(|s| self.fold_statement(s).unwrap()) .collect(); // println!("{:#?}", propagator); @@ -197,13 +209,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .into_iter() .zip(arguments) .map(|(v, a)| { - TypedStatement::definition(transformer.fold_assignee(v.into()), a) + TypedStatement::definition( + ShallowTransformer::with_versions(self.versions) + .fold_assignee(v.into()), + a, + ) }) .collect(); let input_bindings: Vec<_> = input_bindings .into_iter() - .flat_map(|s| propagator.fold_statement(s).unwrap()) + .flat_map(|s| self.propagator.fold_statement(s).unwrap()) .collect(); self.statement_buffer.extend(input_bindings); @@ -226,14 +242,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let expression = propagator.fold_expression(expression).unwrap(); - // println!("call result reduced {}", expression); - - // clean versions - self.versions.map.remove(&self.versions.frame); + let expression = self.fold_expression(expression).unwrap(); - // restore the original frame - // println!("RESTORING BACKUP {}", frame_backup); - self.versions.frame = frame_backup; + // println!("call result reduced {}", expression); Ok(FunctionCallOrExpression::Expression( E::from(expression).into_inner(), @@ -243,23 +254,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => Ok( - FunctionCallOrExpression::Expression(E::function_call(key, generics, arguments)), - ), + Err(InlineError::NonConstant(key, generics, arguments, _)) => Err(Error::NoProgress), Err(InlineError::Flat(embed, generics, arguments, output_type)) => { - let identifier = Identifier::from(CoreIdentifier::Call(0)).version( - *self - .versions - .map - .entry(self.versions.frame) - .or_default() - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ); + let identifier = + Identifier::from(CoreIdentifier::Call(0).in_frame(self.versions.frame)) + .version( + *self + .versions + .map + .entry(self.versions.frame) + .or_default() + .entry(CoreIdentifier::Call(0)) + .and_modify(|e| *e += 1) // if it was already declared, we increment + .or_insert(0), + ); let var = Variable::immutable(identifier.clone(), output_type); - let v = var.clone().into(); + + let v: TypedAssignee<'ast, T> = var.clone().into(); self.statement_buffer .push(TypedStatement::embed_call_definition( @@ -270,7 +282,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { identifier, ))) } - } + }; + + // clean versions + self.versions.map.remove(&self.versions.frame); + + // restore the original frame + // println!("RESTORING BACKUP {}", frame_backup); + self.versions.frame = frame_backup; + + res } fn fold_block_expression>( @@ -310,6 +331,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result>, Self::Error> { let res = match s { TypedStatement::Definition(a, rhs) => { + // usually we transform and then propagate + // for definitions we need special treatment: we transform and propagate the rhs (which can contain function calls) + // then we reduce the rhs to remove the function calls + // only then we transform and propagate the assignee + let mut transformer = ShallowTransformer::with_versions(self.versions); // println!("rhs {}", rhs); let rhs = transformer.fold_definition_rhs(rhs); @@ -321,23 +347,28 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { // println!("ASSIGNEE {}", a); + // println!("{:?}", self.versions); + let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a); - // println!("SSA ASSIGNEE {}", a); + // println!("{:?}", self.versions); + + // println!("definition before propagation {}", TypedStatement::Definition(a.clone(), rhs.clone())); - let s = self - .propagator + self.propagator .fold_statement(TypedStatement::Definition(a, rhs)) - .unwrap(); + .unwrap() - self.statement_buffer.drain(..).chain(s).collect::>() + // println!("final definition size: {}", s.len()); } TypedStatement::For(v, from, to, statements) => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - let from = transformer.fold_uint_expression(from); + let from = + ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); let from = self.propagator.fold_uint_expression(from).unwrap(); - let to = transformer.fold_uint_expression(to); + let from = self.fold_uint_expression(from).unwrap(); + let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); let to = self.propagator.fold_uint_expression(to).unwrap(); + let to = self.fold_uint_expression(to).unwrap(); match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to) @@ -350,11 +381,29 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .flat_map(|s| self.fold_statement(s).unwrap()) .collect::>() }) - .collect(), + .collect::>(), _ => unimplemented!(), } } TypedStatement::Assembly(_) => todo!(), + TypedStatement::Return(e) => { + let mut transformer = ShallowTransformer::with_versions(self.versions); + + let e = transformer.fold_expression(e); + let e = self.propagator.fold_expression(e).unwrap(); + vec![TypedStatement::Return(self.fold_expression(e).unwrap())] + } + TypedStatement::Assertion(e, error) => { + let mut transformer = ShallowTransformer::with_versions(self.versions); + + let e = transformer.fold_boolean_expression(e); + let e = self.propagator.fold_boolean_expression(e).unwrap(); + + vec![TypedStatement::Assertion( + self.fold_boolean_expression(e).unwrap(), + error, + )] + } s => { let mut transformer = ShallowTransformer::with_versions(self.versions); let propagator = &mut self.propagator; @@ -363,11 +412,19 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .into_iter() // .inspect(|s| println!("ssa {}\n", s)) .flat_map(|s| propagator.fold_statement(s).unwrap()) + .collect::>() + .into_iter() + .flat_map(|s| fold_statement(self, s).unwrap()) // .inspect(|s| println!("propagated {}\n", s)) .collect() } }; - Ok(res) + + Ok(self + .statement_buffer + .drain(..) + .chain(res) + .collect::>()) } // fn fold_statement( @@ -443,20 +500,26 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { match e { ArrayExpressionInner::Slice(box array, box from, box to) => { - // let array = self.fold_array_expression(array)?; - // let from = self.fold_uint_expression(from)?; - // let to = self.fold_uint_expression(to)?; - - // match (from.as_inner(), to.as_inner()) { - // (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { - // Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - // } - // _ => { - // self.complete = false; - // Ok(ArrayExpressionInner::Slice(box array, box from, box to)) - // } - // } - todo!() + let array = + ShallowTransformer::with_versions(self.versions).fold_array_expression(array); + let array = self.propagator.fold_array_expression(array).unwrap(); + let array = self.fold_array_expression(array).unwrap(); + let from = + ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from).unwrap(); + let from = self.fold_uint_expression(from).unwrap(); + let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to).unwrap(); + let to = self.fold_uint_expression(to).unwrap(); + + match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { + Ok(ArrayExpressionInner::Slice(box array, box from, box to)) + } + _ => { + todo!("non constant slice bounds") + } + } } _ => fold_array_expression_inner(self, array_ty, e), } diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index aa4ab6a2e..5e744fe23 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -103,10 +103,8 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { fold_function(self, f) } -} -impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { - fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + pub fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { match a { TypedAssignee::Identifier(v) => { let v = self.issue_next_ssa_variable(v); @@ -115,7 +113,9 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { a => fold_assignee(self, a), } } +} +impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index d0458a3c8..25c84c292 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -163,7 +163,7 @@ pub trait ResultFolder<'ast, T: Field>: Sized { id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)?), frame: 0, }, - id => n.id, + _ => n.id, }; Ok(Identifier { id, ..n }) diff --git a/zokrates_core_test/tests/tests/uint/rotate.zok b/zokrates_core_test/tests/tests/uint/rotate.zok index c2ba820f9..c593ae18e 100644 --- a/zokrates_core_test/tests/tests/uint/rotate.zok +++ b/zokrates_core_test/tests/tests/uint/rotate.zok @@ -3,7 +3,8 @@ import "utils/casts/u32_from_bits" as from_bits; def right_rotate(u32 e) -> u32 { bool[32] b = to_bits(e); - return from_bits([...b[32-N..], ...b[..32-N]]); + u32 res = from_bits([...b[32-N..], ...b[..32-N]]); + return res; } def main(u32 e) -> u32 { From 99f145566fce3c80e3be2832ea07f9943c635152 Mon Sep 17 00:00:00 2001 From: dark64 Date: Mon, 20 Feb 2023 19:24:16 +0100 Subject: [PATCH 08/19] fix mpc init message --- zokrates_cli/src/ops/mpc/init.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zokrates_cli/src/ops/mpc/init.rs b/zokrates_cli/src/ops/mpc/init.rs index eb7ba16e4..92a972482 100644 --- a/zokrates_cli/src/ops/mpc/init.rs +++ b/zokrates_cli/src/ops/mpc/init.rs @@ -24,8 +24,8 @@ pub fn subcommand() -> App<'static, 'static> { .arg( Arg::with_name("radix-path") .short("r") - .long("radix-dir") - .help("Path of the directory containing parameters for various 2^m circuit depths (phase1radix2m{0..=m})") + .long("radix-path") + .help("Path of the phase1radix2m{n} file") .value_name("PATH") .takes_value(true) .required(true), From 1bb524f6a2f81ac56113b4557ed963dcd0203df8 Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 20 Feb 2023 23:17:02 +0100 Subject: [PATCH 09/19] clean --- zokrates_analysis/src/propagation.rs | 21 +- .../src/reducer/constants_reader.rs | 12 +- .../src/reducer/constants_writer.rs | 6 +- zokrates_analysis/src/reducer/inline.rs | 5 +- zokrates_analysis/src/reducer/mod.rs | 556 +++++------------- zokrates_analysis/src/reducer/shallow_ssa.rs | 71 +-- zokrates_ast/src/typed/folder.rs | 2 +- zokrates_ast/src/typed/types.rs | 2 +- 8 files changed, 202 insertions(+), 473 deletions(-) diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index b7701dc01..8a2f32e04 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -44,25 +44,16 @@ impl fmt::Display for Error { } } -#[derive(Debug)] -pub struct Propagator<'ast, 'a, T> { +#[derive(Debug, Default)] +pub struct Propagator<'ast, T> { // constants keeps track of constant expressions // we currently do not support partially constant expressions: `field [x, 1][1]` is not considered constant, `field [0, 1][1]` is - constants: &'a mut Constants<'ast, T>, + constants: Constants<'ast, T>, } -impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { - pub fn with_constants(constants: &'a mut Constants<'ast, T>) -> Self { - Propagator { constants } - } - +impl<'ast, T: Field> Propagator<'ast, T> { pub fn propagate(p: TypedProgram<'ast, T>) -> Result, Error> { - let mut constants = Constants::new(); - - Propagator { - constants: &mut constants, - } - .fold_program(p) + Propagator::default().fold_program(p) } // get a mutable reference to the constant corresponding to a given assignee if any, otherwise @@ -141,7 +132,7 @@ impl<'ast, 'a, T: Field> Propagator<'ast, 'a, T> { } } -impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Propagator<'ast, 'a, T> { +impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { type Error = Error; fn fold_program(&mut self, p: TypedProgram<'ast, T>) -> Result, Error> { diff --git a/zokrates_analysis/src/reducer/constants_reader.rs b/zokrates_analysis/src/reducer/constants_reader.rs index f0ec252e8..f991f77fd 100644 --- a/zokrates_analysis/src/reducer/constants_reader.rs +++ b/zokrates_analysis/src/reducer/constants_reader.rs @@ -65,7 +65,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -94,7 +94,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -124,7 +124,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -152,7 +152,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -182,7 +182,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, @@ -212,7 +212,7 @@ impl<'a, 'ast, T: Field> Folder<'ast, T> for ConstantsReader<'a, 'ast, T> { id: FrameIdentifier { id: CoreIdentifier::Constant(c), - frame, + frame: _, }, version, }, diff --git a/zokrates_analysis/src/reducer/constants_writer.rs b/zokrates_analysis/src/reducer/constants_writer.rs index d38daf150..50a6d25ad 100644 --- a/zokrates_analysis/src/reducer/constants_writer.rs +++ b/zokrates_analysis/src/reducer/constants_writer.rs @@ -5,9 +5,9 @@ use crate::reducer::{ }; use std::collections::{BTreeMap, HashSet}; use zokrates_ast::typed::{ - result_folder::*, types::ConcreteGenericsAssignment, Constant, OwnedTypedModuleId, Typed, - TypedConstant, TypedConstantSymbol, TypedConstantSymbolDeclaration, TypedModuleId, - TypedProgram, TypedSymbolDeclaration, UExpression, + result_folder::*, Constant, OwnedTypedModuleId, Typed, TypedConstant, TypedConstantSymbol, + TypedConstantSymbolDeclaration, TypedModuleId, TypedProgram, TypedSymbolDeclaration, + UExpression, }; use zokrates_field::Field; diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index f1a7229e6..7e2436cdc 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -26,13 +26,10 @@ // - The body of the function is in SSA form // - The return value(s) are assigned to internal variables -use crate::reducer::ShallowTransformer; -use crate::reducer::Versions; - use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; use zokrates_ast::typed::CoreIdentifier; -use zokrates_ast::typed::Identifier; + use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::UBitwidth; use zokrates_ast::typed::{ diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index bd129890c..51dcc3aab 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -18,30 +18,26 @@ mod shallow_ssa; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; -use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::result_folder::*; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::GGenericsAssignment; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; -use zokrates_ast::typed::Typed; +use zokrates_ast::typed::TypedAssemblyStatement; use zokrates_ast::typed::TypedAssignee; -use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; - use zokrates_ast::typed::{ ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, - FunctionCallExpression, FunctionCallOrExpression, Id, Identifier, OwnedTypedModuleId, - TypedExpression, TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, - TypedModule, TypedProgram, TypedStatement, UExpression, UExpressionInner, + FunctionCallExpression, FunctionCallOrExpression, Id, OwnedTypedModuleId, TypedExpression, + TypedFunction, TypedFunctionSymbol, TypedFunctionSymbolDeclaration, TypedModule, TypedProgram, + TypedStatement, UExpression, UExpressionInner, }; +use zokrates_ast::typed::{CanonicalConstantIdentifier, EmbedCall, Variable}; -use zokrates_ast::zir::result_folder::fold_assembly_statement; use zokrates_field::Field; use self::constants_writer::ConstantsWriter; use self::shallow_ssa::ShallowTransformer; -use crate::propagation::{Constants, Propagator}; +use crate::propagation; +use crate::propagation::Propagator; use std::fmt; @@ -55,7 +51,6 @@ pub type ConstantDefinitions<'ast, T> = #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Versions<'ast> { map: HashMap, usize>>, - frame: usize, } impl<'ast> Versions<'ast> { @@ -69,15 +64,15 @@ impl<'ast> Versions<'ast> { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] pub enum Error { Incompatible(String), GenericsInMain, - // TODO: give more details about what's blocking the progress - NoProgress, LoopTooLarge(u128), ConstantReduction(String, OwnedTypedModuleId), + NonConstant(String), Type(String), + Propagation(propagation::Error), } impl fmt::Display for Error { @@ -89,39 +84,36 @@ impl fmt::Display for Error { s ), Error::GenericsInMain => write!(f, "Cannot generate code for generic function"), - Error::NoProgress => write!(f, "Failed to unroll or inline program. Check that main function arguments aren't used as array size or for-loop bounds"), Error::LoopTooLarge(size) => write!(f, "Found a loop of size {}, which is larger than the maximum allowed of {}. Check the loop bounds, especially for underflows", size, MAX_FOR_LOOP_SIZE), Error::ConstantReduction(name, module) => write!(f, "Failed to reduce constant `{}` in module `{}` to a literal, try simplifying its declaration", name, module.display()), - Error::Type(message) => write!(f, "{}", message), + Error::NonConstant(s) => write!(f, "{}", s), + Error::Type(s) => write!(f, "{}", s), + Error::Propagation(e) => write!(f, "{}", e), } } } +impl From for Error { + fn from(e: propagation::Error) -> Self { + Self::Propagation(e) + } +} + #[derive(Debug)] struct Reducer<'ast, 'a, T> { - propagator: Propagator<'ast, 'a, T>, - statement_buffer: Vec>, - latest_frame: usize, - versions: &'a mut Versions<'ast>, program: &'a TypedProgram<'ast, T>, + propagator: Propagator<'ast, T>, + ssa: ShallowTransformer<'ast>, + statement_buffer: Vec>, } impl<'ast, 'a, T: Field> Reducer<'ast, 'a, T> { - fn new( - program: &'a TypedProgram<'ast, T>, - versions: &'a mut Versions<'ast>, - constants: &'a mut Constants<'ast, T>, - ) -> Self { - // println!("create reducer with"); - // println!("{} versions", versions.len()); - // println!("{} constants", constants.len()); - + fn new(program: &'a TypedProgram<'ast, T>) -> Self { Reducer { - propagator: Propagator::with_constants(constants), + propagator: Propagator::default(), + ssa: ShallowTransformer::default(), statement_buffer: vec![], - latest_frame: 0, program, - versions, } } } @@ -133,8 +125,9 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, p: DeclarationParameter<'ast, T>, ) -> Result, Self::Error> { + // this is only used on the entry point let id = p.id.id.id.id.clone(); - assert!(self.versions.insert_in_frame(id, 0, 0).is_none()); + assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none()); Ok(p) } @@ -150,9 +143,8 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .into_iter() .map(|g| { g.map(|g| { - let g = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(g); - let g = self.propagator.fold_uint_expression(g).unwrap(); + let g = self.ssa.fold_uint_expression(g); + let g = self.propagator.fold_uint_expression(g)?; self.fold_uint_expression(g) }) @@ -164,43 +156,30 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .arguments .into_iter() .map(|e| { - let e = ShallowTransformer::with_versions(self.versions).fold_expression(e); - let e = self.propagator.fold_expression(e).unwrap(); + let e = self.ssa.fold_expression(e); + let e = self.propagator.fold_expression(e)?; self.fold_expression(e) }) .collect::>()?; - // back up the current frame - let frame_backup = self.versions.frame; - - // create a new frame - self.latest_frame += 1; - - // point the versions to this frame - self.versions.frame = self.latest_frame; - self.versions - .map - .insert(self.versions.frame, Default::default()); - - // println!("GENERICS {:?}", generics); + self.ssa.push_call_frame(); - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); + let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program); let res = match res { Ok((input_variables, arguments, generics_bindings, statements, expression)) => { - let generics_bindings: Vec<_> = generics_bindings + let generics_bindings = generics_bindings .into_iter() - .flat_map(|s| { - ShallowTransformer::with_versions(self.versions).fold_statement(s) - }) - .flat_map(|s| self.propagator.fold_statement(s).unwrap()) - .collect::>() + .flat_map(|s| self.ssa.fold_statement(s)) + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? .into_iter() - .flat_map(|s| self.fold_statement(s).unwrap()) - .collect(); - - // println!("{:#?}", propagator); + .flatten() + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(generics_bindings); @@ -208,43 +187,32 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) - .map(|(v, a)| { - TypedStatement::definition( - ShallowTransformer::with_versions(self.versions) - .fold_assignee(v.into()), - a, - ) - }) + .map(|(v, a)| TypedStatement::definition(self.ssa.fold_assignee(v.into()), a)) .collect(); - let input_bindings: Vec<_> = input_bindings + let input_bindings = input_bindings .into_iter() - .flat_map(|s| self.propagator.fold_statement(s).unwrap()) - .collect(); + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(input_bindings); - let statements: Vec<_> = statements + let statements = statements .into_iter() - .flat_map(|s| self.fold_statement(s).unwrap()) - .collect(); - - // println!("FRAME READY TO SSA {}", self.versions.frame); - - let mut transformer = ShallowTransformer::with_versions(self.versions); - let propagator = &mut self.propagator; + .map(|s| self.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); self.statement_buffer.extend(statements); - // println!("call result {}", expression); - - let expression = transformer.fold_expression(expression); + let expression = self.ssa.fold_expression(expression); - let expression = propagator.fold_expression(expression).unwrap(); + let expression = self.propagator.fold_expression(expression)?; - let expression = self.fold_expression(expression).unwrap(); - - // println!("call result reduced {}", expression); + let expression = self.fold_expression(expression)?; Ok(FunctionCallOrExpression::Expression( E::from(expression).into_inner(), @@ -254,20 +222,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => Err(Error::NoProgress), + Err(InlineError::NonConstant(key, generics, arguments, _)) => { + Err(Error::NonConstant(format!( + "Generic parameters must be compile-time constants, found {}", + FunctionCallExpression::<_, E>::new(key, generics, arguments) + ))) + } Err(InlineError::Flat(embed, generics, arguments, output_type)) => { - let identifier = - Identifier::from(CoreIdentifier::Call(0).in_frame(self.versions.frame)) - .version( - *self - .versions - .map - .entry(self.versions.frame) - .or_default() - .entry(CoreIdentifier::Call(0)) - .and_modify(|e| *e += 1) // if it was already declared, we increment - .or_insert(0), - ); + let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let var = Variable::immutable(identifier.clone(), output_type); @@ -284,12 +246,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { } }; - // clean versions - self.versions.map.remove(&self.versions.frame); - - // restore the original frame - // println!("RESTORING BACKUP {}", frame_backup); - self.versions.frame = frame_backup; + self.ssa.pop_call_frame(); res } @@ -298,24 +255,44 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, b: BlockExpression<'ast, T, E>, ) -> Result, Self::Error> { - // // backup the statements and continue with a fresh state - // let statement_buffer = std::mem::take(&mut self.statement_buffer); - - // let block = fold_block_expression(self, b)?; - - // // put the original statements back and extract the statements created by visiting the block - // let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); - - // // return the visited block, augmented with the statements created while visiting it - // Ok(BlockExpression { - // statements: block - // .statements - // .into_iter() - // .chain(extra_statements) - // .collect(), - // ..block - // }) - todo!() + // backup the statements and continue with a fresh state + let statement_buffer = std::mem::take(&mut self.statement_buffer); + + let block = fold_block_expression(self, b)?; + + // put the original statements back and extract the statements created by visiting the block + let extra_statements = std::mem::replace(&mut self.statement_buffer, statement_buffer); + + // return the visited block, augmented with the statements created while visiting it + Ok(BlockExpression { + statements: block + .statements + .into_iter() + .chain(extra_statements) + .collect(), + ..block + }) + } + + fn fold_assembly_statement( + &mut self, + s: TypedAssemblyStatement<'ast, T>, + ) -> Result>, Self::Error> { + Ok(match s { + TypedAssemblyStatement::Assignment(a, e) => { + vec![TypedAssemblyStatement::Assignment( + self.fold_assignee(a)?, + self.fold_expression(e)?, + )] + } + TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { + vec![TypedAssemblyStatement::Constraint( + self.fold_field_expression(lhs)?, + self.fold_field_expression(rhs)?, + metadata, + )] + } + }) } fn fold_canonical_constant_identifier( @@ -336,163 +313,77 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { // then we reduce the rhs to remove the function calls // only then we transform and propagate the assignee - let mut transformer = ShallowTransformer::with_versions(self.versions); - // println!("rhs {}", rhs); - let rhs = transformer.fold_definition_rhs(rhs); - // println!("ssa-ed {}", rhs); - let rhs = self.propagator.fold_definition_rhs(rhs).unwrap(); - // println!("propagated {}", rhs); - let rhs = self.fold_definition_rhs(rhs).unwrap(); - // println!("reduced {}", rhs); - - // println!("ASSIGNEE {}", a); - - // println!("{:?}", self.versions); + let rhs = self.ssa.fold_definition_rhs(rhs); + let rhs = self.propagator.fold_definition_rhs(rhs)?; + let rhs = self.fold_definition_rhs(rhs)?; - let a = ShallowTransformer::with_versions(self.versions).fold_assignee(a); - - // println!("{:?}", self.versions); - - // println!("definition before propagation {}", TypedStatement::Definition(a.clone(), rhs.clone())); + let a = self.ssa.fold_assignee(a); self.propagator - .fold_statement(TypedStatement::Definition(a, rhs)) - .unwrap() - - // println!("final definition size: {}", s.len()); + .fold_statement(TypedStatement::Definition(a, rhs))? } TypedStatement::For(v, from, to, statements) => { - let from = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from).unwrap(); - let from = self.fold_uint_expression(from).unwrap(); - let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to).unwrap(); - let to = self.fold_uint_expression(to).unwrap(); + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { - (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from..*to) + (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from + ..*to) .flat_map(|index| { std::iter::once(TypedStatement::definition( v.clone().into(), UExpression::from(index as u32).into(), )) .chain(statements.clone()) - .flat_map(|s| self.fold_statement(s).unwrap()) + .map(|s| self.fold_statement(s)) .collect::>() }) - .collect::>(), - _ => unimplemented!(), - } + .collect::, _>>()? + .into_iter() + .flatten() + .collect()), + _ => Err(Error::NonConstant(format!( + "Expected loop bounds to be constant, found {}..{}", + from, to + ))), + }? } - TypedStatement::Assembly(_) => todo!(), TypedStatement::Return(e) => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - - let e = transformer.fold_expression(e); - let e = self.propagator.fold_expression(e).unwrap(); - vec![TypedStatement::Return(self.fold_expression(e).unwrap())] + let e = self.ssa.fold_expression(e); + let e = self.propagator.fold_expression(e)?; + vec![TypedStatement::Return(self.fold_expression(e)?)] } TypedStatement::Assertion(e, error) => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - - let e = transformer.fold_boolean_expression(e); - let e = self.propagator.fold_boolean_expression(e).unwrap(); + let e = self.ssa.fold_boolean_expression(e); + let e = self.propagator.fold_boolean_expression(e)?; vec![TypedStatement::Assertion( - self.fold_boolean_expression(e).unwrap(), + self.fold_boolean_expression(e)?, error, )] } - s => { - let mut transformer = ShallowTransformer::with_versions(self.versions); - let propagator = &mut self.propagator; - transformer - .fold_statement(s) - .into_iter() - // .inspect(|s| println!("ssa {}\n", s)) - .flat_map(|s| propagator.fold_statement(s).unwrap()) - .collect::>() - .into_iter() - .flat_map(|s| fold_statement(self, s).unwrap()) - // .inspect(|s| println!("propagated {}\n", s)) - .collect() - } + s => self + .ssa + .fold_statement(s) + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten() + .map(|s| fold_statement(self, s)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(), }; - Ok(self - .statement_buffer - .drain(..) - .chain(res) - .collect::>()) + Ok(self.statement_buffer.drain(..).chain(res).collect()) } - // fn fold_statement( - // &mut self, - // s: TypedStatement<'ast, T>, - // ) -> Result>, Self::Error> { - // let mut transformer = ShallowTransformer::with_versions(self.versions); - // let propagator = &mut self.propagator; - - // println!("FOLD_STATEMENT: {}", s); - - // let s: Vec<_> = transformer - // .fold_statement(s) - // .into_iter() - // // .inspect(|s| println!("ssa {}\n", s)) - // .flat_map(|s| propagator.fold_statement(s).unwrap()) - // // .inspect(|s| println!("propagated {}\n", s)) - // .collect(); - - // for s in &s { - // println!("OUTER: {}", s); - // } - - // let res: Vec<_> = s - // .into_iter() - // .flat_map(|s| match s { - // TypedStatement::For(v, from, to, statements) => { - // match (from.as_inner(), to.as_inner()) { - // (UExpressionInner::Value(from), UExpressionInner::Value(to)) => (*from - // ..*to) - // .flat_map(|index| { - // std::iter::once(TypedStatement::definition( - // v.clone().into(), - // UExpression::from(index as u32).into(), - // )) - // .chain(statements.clone()) - // .flat_map(|s| self.fold_statement(s).unwrap()) - // .collect::>() - // }) - // .collect(), - // _ => unimplemented!(), - // } - // } - // s => { - // println!("UNROLL/INLINE STATEMENT {}", s); - - // let s = fold_statement(self, s).unwrap(); - - // for s in &self.statement_buffer { - // println!("BUFFER {}", s); - // } - - // for s in &s { - // println!("RESULT {}", s); - // } - - // self.statement_buffer.drain(..).chain(s).collect::>() - // } - // }) - // .collect(); - - // for s in &res { - // // println!("DONE: {}", s); - // } - - // Ok(res) - // } - fn fold_array_expression_inner( &mut self, array_ty: &ArrayType<'ast, T>, @@ -500,25 +391,24 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { match e { ArrayExpressionInner::Slice(box array, box from, box to) => { - let array = - ShallowTransformer::with_versions(self.versions).fold_array_expression(array); - let array = self.propagator.fold_array_expression(array).unwrap(); - let array = self.fold_array_expression(array).unwrap(); - let from = - ShallowTransformer::with_versions(self.versions).fold_uint_expression(from); - let from = self.propagator.fold_uint_expression(from).unwrap(); - let from = self.fold_uint_expression(from).unwrap(); - let to = ShallowTransformer::with_versions(self.versions).fold_uint_expression(to); - let to = self.propagator.fold_uint_expression(to).unwrap(); - let to = self.fold_uint_expression(to).unwrap(); + let array = self.ssa.fold_array_expression(array); + let array = self.propagator.fold_array_expression(array)?; + let array = self.fold_array_expression(array)?; + let from = self.ssa.fold_uint_expression(from); + let from = self.propagator.fold_uint_expression(from)?; + let from = self.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); + let to = self.propagator.fold_uint_expression(to)?; + let to = self.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { Ok(ArrayExpressionInner::Slice(box array, box from, box to)) } - _ => { - todo!("non constant slice bounds") - } + _ => Err(Error::NonConstant(format!( + "Slice bounds must be compile time constants, found {}", + ArrayExpressionInner::Slice(box array, box from, box to) + ))), } } _ => fold_array_expression_inner(self, array_ty, e), @@ -548,7 +438,7 @@ pub fn reduce_program(p: TypedProgram) -> Result, E match main_function.signature.generics.len() { 0 => { - let main_function = reduce_function(main_function, &p)?; + let main_function = Reducer::new(&p).fold_function(main_function)?; Ok(TypedProgram { main: p.main.clone(), @@ -574,135 +464,9 @@ fn reduce_function<'ast, T: Field>( f: TypedFunction<'ast, T>, program: &TypedProgram<'ast, T>, ) -> Result, Error> { - let mut versions = Versions::default(); - let mut constants = Constants::default(); - assert!(f.signature.generics.is_empty()); - // let f = match ShallowTransformer::transform(f, &generics, &mut versions) { - // Output::Complete(f) => Ok(f), - // Output::Incomplete(new_f, new_for_loop_versions) => { - // let mut for_loop_versions = new_for_loop_versions; - - // let mut f = Propagator::with_constants(&mut constants) - // .fold_function(new_f) - // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // let mut substitutions = Substitutions::default(); - - // let mut hash = None; - - // let mut len = f.statements.len(); - - // // println!("{}", f); - - // loop { - // let mut reducer = Reducer::new( - // program, - // &mut versions, - // &mut substitutions, - // for_loop_versions, - // &mut constants, - // ); - - // println!("reduce"); - - // let new_f = TypedFunction { - // statements: f - // .statements - // .into_iter() - // .map(|s| reducer.fold_statement(s)) - // .collect::, _>>()? - // .into_iter() - // .flatten() - // .collect(), - // ..f - // }; - - // println!("done"); - - // // println!("after reduction {}", new_f); - - // println!( - // "count {}, unrolled {} loops", - // new_f.statements.len(), - // reducer.credits - // ); - - // assert!(reducer.for_loop_versions.is_empty()); - - // match reducer.complete { - // true => { - // substitutions = substitutions.canonicalize(); - - // let new_f = Sub::new(&substitutions).fold_function(new_f); - - // // println!("after last sub {}", new_f); - - // // let new_f = Propagator::with_constants(&mut constants) - // // .fold_function(new_f) - // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // // println!("after last prop {}", new_f); - - // break Ok(new_f); - // } - // false => { - // for_loop_versions = reducer.for_loop_versions_after; - - // println!("canonicalize"); - - // // substitutions = substitutions.canonicalize(); - - // // let new_f = Sub::new(&substitutions).fold_function(new_f); - - // println!("done"); - // // println!("after sub {}", new_f); - - // println!("propagate"); - - // // f = Propagator::with_constants(&mut constants) - // // .fold_function(new_f) - // // .map_err(|e| Error::Incompatible(format!("{}", e)))?; - - // println!("done"); - - // f = new_f; - - // // println!("after prop {}", f); - - // let new_len = f.statements.len(); - - // if new_len == len { - // let new_hash = Some(compute_hash(&f)); - - // if new_hash == hash { - // break Err(Error::NoProgress); - // } else { - // hash = new_hash; - // } - // } else { - // len = new_len; - // } - // } - // } - // } - // } - // }?; - - // Propagator::with_constants(&mut constants) - // .fold_function(f) - // .map_err(|e| Error::Incompatible(format!("{}", e))) - - Reducer::new(program, &mut versions, &mut constants).fold_function(f) -} - -fn compute_hash(f: &TypedFunction) -> u64 { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut s = DefaultHasher::new(); - f.hash(&mut s); - s.finish() + Reducer::new(program).fold_function(f) } #[cfg(test)] diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index 5e744fe23..0d4509506 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -28,26 +28,24 @@ use zokrates_ast::typed::folder::*; use zokrates_ast::typed::identifier::FrameIdentifier; -use zokrates_ast::typed::types::ConcreteGenericsAssignment; -use zokrates_ast::typed::types::Type; + use zokrates_ast::typed::*; use zokrates_field::Field; use super::Versions; -pub struct ShallowTransformer<'ast, 'a> { +#[derive(Debug, Default)] +pub struct ShallowTransformer<'ast> { // version index for any variable name - pub versions: &'a mut Versions<'ast>, + pub versions: Versions<'ast>, + pub frames: Vec, + pub latest_frame: usize, } -impl<'ast, 'a> ShallowTransformer<'ast, 'a> { - pub fn with_versions(versions: &'a mut Versions<'ast>) -> Self { - ShallowTransformer { versions } - } - - fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let frame_versions = self.versions.map.entry(self.versions.frame).or_default(); +impl<'ast> ShallowTransformer<'ast> { + pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { + let frame_versions = self.versions.map.entry(self.frame()).or_default(); let version = frame_versions .entry(c_id.clone()) @@ -66,42 +64,21 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { } } - pub fn transform( - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - versions: &'a mut Versions<'ast>, - ) -> TypedFunction<'ast, T> { - let mut unroller = ShallowTransformer::with_versions(versions); - - unroller.fold_function(f, generics) + fn frame(&self) -> usize { + *self.frames.last().unwrap_or(&0) } - fn fold_function( - &mut self, - f: TypedFunction<'ast, T>, - generics: &ConcreteGenericsAssignment<'ast>, - ) -> TypedFunction<'ast, T> { - let mut f = f; - - f.statements = generics - .0 - .clone() - .into_iter() - .map(|(g, v)| { - TypedStatement::definition( - Variable::new(CoreIdentifier::from(g), Type::Uint(UBitwidth::B32), false) - .into(), - UExpression::from(v as u32).into(), - ) - }) - .chain(f.statements) - .collect(); - - for arg in &f.arguments { - let _ = self.issue_next_identifier(arg.id.id.id.id.clone()); - } + pub fn push_call_frame(&mut self) { + self.latest_frame += 1; + self.frames.push(self.latest_frame); + self.versions + .map + .insert(self.latest_frame, Default::default()); + } - fold_function(self, f) + pub fn pop_call_frame(&mut self) { + let frame = self.frames.pop().unwrap(); + self.versions.map.remove(&frame); } pub fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { @@ -115,7 +92,7 @@ impl<'ast, 'a> ShallowTransformer<'ast, 'a> { } } -impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { +impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { fn fold_assembly_statement( &mut self, s: TypedAssemblyStatement<'ast, T>, @@ -154,14 +131,14 @@ impl<'ast, 'a, T: Field> Folder<'ast, T> for ShallowTransformer<'ast, 'a> { let version = self .versions .map - .get(&self.versions.frame) + .get(&self.frame()) .unwrap() .get(&n.id.id) .cloned() .unwrap_or(0); let id = FrameIdentifier { - frame: self.versions.frame, + frame: self.frame(), ..n.id }; diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 70200ecd3..1180874fd 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -135,7 +135,7 @@ pub trait Folder<'ast, T: Field>: Sized { id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)), frame: 0, }, - id => n.id, + _id => n.id, }; Identifier { id, ..n } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 61f49eabf..60d3792fc 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -240,7 +240,7 @@ impl<'ast, T> From for UExpression<'ast, T> { impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { - DeclarationConstant::Generic(g) => { + DeclarationConstant::Generic(_g) => { // UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32) unreachable!() } From 30801c86fa208472cae8c6ce0c966759b1375d10 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 22 Feb 2023 21:30:11 +0100 Subject: [PATCH 10/19] make all tests pass, clean --- .../src/flatten_complex_types.rs | 2 - zokrates_analysis/src/lib.rs | 4 - zokrates_analysis/src/propagation.rs | 238 +++++---- zokrates_analysis/src/reducer/inline.rs | 4 +- zokrates_analysis/src/reducer/mod.rs | 308 ++++-------- zokrates_analysis/src/reducer/shallow_ssa.rs | 465 +++++++----------- zokrates_ast/src/typed/folder.rs | 12 +- zokrates_ast/src/typed/identifier.rs | 15 + zokrates_ast/src/typed/mod.rs | 54 +- zokrates_ast/src/typed/result_folder.rs | 10 +- zokrates_ast/src/typed/types.rs | 6 +- zokrates_core_test/tests/tests/call_ssa.json | 16 + zokrates_core_test/tests/tests/call_ssa.zok | 11 + 13 files changed, 440 insertions(+), 705 deletions(-) create mode 100644 zokrates_core_test/tests/tests/call_ssa.json create mode 100644 zokrates_core_test/tests/tests/call_ssa.zok diff --git a/zokrates_analysis/src/flatten_complex_types.rs b/zokrates_analysis/src/flatten_complex_types.rs index f4b81d8e1..0b834ef8a 100644 --- a/zokrates_analysis/src/flatten_complex_types.rs +++ b/zokrates_analysis/src/flatten_complex_types.rs @@ -629,8 +629,6 @@ fn fold_statement<'ast, T: Field>( }) .collect(), )], - typed::TypedStatement::PushCallLog(..) => vec![], - typed::TypedStatement::PopCallLog => vec![], typed::TypedStatement::For(..) => unreachable!(), }; diff --git a/zokrates_analysis/src/lib.rs b/zokrates_analysis/src/lib.rs index c628e7283..539fe86cb 100644 --- a/zokrates_analysis/src/lib.rs +++ b/zokrates_analysis/src/lib.rs @@ -161,10 +161,6 @@ pub fn analyse<'ast, T: Field>( let r = reduce_program(r).map_err(Error::from)?; log::trace!("\n{}", r); - log::debug!("Static analyser: Propagate"); - let r = Propagator::propagate(r)?; - log::trace!("\n{}", r); - log::debug!("Static analyser: Concretize structs"); let r = StructConcretizer::concretize(r); log::trace!("\n{}", r); diff --git a/zokrates_analysis/src/propagation.rs b/zokrates_analysis/src/propagation.rs index 8a2f32e04..7d77e86db 100644 --- a/zokrates_analysis/src/propagation.rs +++ b/zokrates_analysis/src/propagation.rs @@ -308,21 +308,12 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { } }; - // particular case of `lhs = rhs` - if TypedExpression::from(assignee.clone()) == expr { - return Ok(vec![]); - } - if expr.is_constant() { match assignee { TypedAssignee::Identifier(var) => { let expr = expr.into_canonical_constant(); - assert!( - self.constants.insert(var.clone().id, expr).is_none(), - "{}", - var - ); + assert!(self.constants.insert(var.id, expr).is_none()); Ok(vec![]) } @@ -629,8 +620,6 @@ impl<'ast, T: Field> ResultFolder<'ast, T> for Propagator<'ast, T> { _ => Ok(vec![TypedStatement::Assertion(expr, err)]), } } - s @ TypedStatement::PushCallLog(..) => Ok(vec![s]), - s @ TypedStatement::PopCallLog => Ok(vec![s]), s => fold_statement(self, s), } } @@ -1502,7 +1491,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(5))) ); } @@ -1515,7 +1504,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(1))) ); } @@ -1528,7 +1517,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(6))) ); } @@ -1541,7 +1530,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1554,15 +1543,14 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(8))) ); } #[test] fn left_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::LeftShift( @@ -1607,8 +1595,7 @@ mod tests { #[test] fn right_shift() { - let mut constants = Constants::new(); - let mut propagator = Propagator::with_constants(&mut constants); + let mut propagator = Propagator::default(); assert_eq!( propagator.fold_field_expression(FieldElementExpression::RightShift( @@ -1676,7 +1663,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(2))) ); } @@ -1691,7 +1678,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1713,7 +1700,7 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()).fold_field_expression(e), + Propagator::default().fold_field_expression(e), Ok(FieldElementExpression::Number(Bn128Field::from(3))) ); } @@ -1735,18 +1722,15 @@ mod tests { BooleanExpression::Not(box BooleanExpression::identifier("a".into())); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_default.clone()), + Propagator::default().fold_boolean_expression(e_default.clone()), Ok(e_default) ); } @@ -1776,23 +1760,19 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); } @@ -1800,38 +1780,42 @@ mod tests { #[test] fn bool_eq() { assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(true), BooleanExpression::Value(false) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::BoolEq(EqExpression::new( + Propagator::::default().fold_boolean_expression( + BooleanExpression::BoolEq(EqExpression::new( BooleanExpression::Value(false), BooleanExpression::Value(true) - ))), + )) + ), Ok(BooleanExpression::Value(false)) ); } @@ -1933,33 +1917,27 @@ mod tests { )); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_true), + Propagator::default().fold_boolean_expression(e_constant_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_constant_false), + Propagator::default().fold_boolean_expression(e_constant_false), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_true), + Propagator::default().fold_boolean_expression(e_identifier_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_identifier_unchanged.clone()), + Propagator::default().fold_boolean_expression(e_identifier_unchanged.clone()), Ok(e_identifier_unchanged) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_true), + Propagator::default().fold_boolean_expression(e_non_canonical_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_non_canonical_false), + Propagator::default().fold_boolean_expression(e_non_canonical_false), Ok(BooleanExpression::Value(false)) ); } @@ -1977,13 +1955,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2001,13 +1977,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2025,13 +1999,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2049,13 +2021,11 @@ mod tests { ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_true), + Propagator::default().fold_boolean_expression(e_true), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::with_constants(&mut Constants::new()) - .fold_boolean_expression(e_false), + Propagator::default().fold_boolean_expression(e_false), Ok(BooleanExpression::Value(false)) ); } @@ -2065,67 +2035,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::And( + Propagator::::default().fold_boolean_expression( + BooleanExpression::And( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } @@ -2135,67 +2113,75 @@ mod tests { let a_bool: Identifier = "a".into(); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::identifier(a_bool.clone()) - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::identifier(a_bool.clone()), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::identifier(a_bool.clone())) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(true), box BooleanExpression::Value(true), - )), + ) + ), Ok(BooleanExpression::Value(true)) ); assert_eq!( - Propagator::::with_constants(&mut Constants::new()) - .fold_boolean_expression(BooleanExpression::Or( + Propagator::::default().fold_boolean_expression( + BooleanExpression::Or( box BooleanExpression::Value(false), box BooleanExpression::Value(false), - )), + ) + ), Ok(BooleanExpression::Value(false)) ); } diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 7e2436cdc..002228f7d 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -135,7 +135,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( } }; - let decl = get_canonical_function(&k, program); + let decl = get_canonical_function(k, program); // get an assignment of generics for this call site let assignment: ConcreteGenericsAssignment<'ast> = k @@ -190,7 +190,7 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .into_iter() .zip(inferred_signature.inputs.clone()) .map(|(p, t)| ConcreteVariable::new(p.id.id, t, false)) - .map(|v| Variable::from(v)) + .map(Variable::from) .collect(); let (statements, mut returns): (Vec<_>, Vec<_>) = f diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 51dcc3aab..b2d5da5c6 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -3,13 +3,14 @@ // - free of function calls (except for low level calls) thanks to inlining // - free of for-loops thanks to unrolling -// The process happens in two steps -// 1. Shallow SSA for the `main` function -// We turn the `main` function into SSA form, but ignoring function calls and for loops -// 2. Unroll and inline -// We go through the shallow-SSA program and -// - unroll loops -// - inline function calls. This includes applying shallow-ssa on the target function +// The process happens in a greedy way, starting from the main function +// For each statement: +// * put it in ssa form +// * propagate it +// * inline it (calling this process recursively) +// * propagate again + +// if at any time a generic parameter or loop bound is not constant, error out, because it should have been propagated to a constant by the greedy approach mod constants_reader; mod constants_writer; @@ -21,7 +22,6 @@ use std::collections::HashMap; use zokrates_ast::typed::result_folder::*; use zokrates_ast::typed::DeclarationParameter; use zokrates_ast::typed::Folder; -use zokrates_ast::typed::TypedAssemblyStatement; use zokrates_ast::typed::TypedAssignee; use zokrates_ast::typed::{ ArrayExpressionInner, ArrayType, BlockExpression, CoreIdentifier, Expr, FunctionCall, @@ -47,23 +47,6 @@ const MAX_FOR_LOOP_SIZE: u128 = 2u128.pow(20); pub type ConstantDefinitions<'ast, T> = HashMap, TypedExpression<'ast, T>>; -// An SSA version map, giving access to the latest version number for each identifier -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct Versions<'ast> { - map: HashMap, usize>>, -} - -impl<'ast> Versions<'ast> { - fn insert_in_frame( - &mut self, - id: CoreIdentifier<'ast>, - version: usize, - frame: usize, - ) -> Option { - self.map.entry(frame).or_default().insert(id, version) - } -} - #[derive(Debug, PartialEq, Eq)] pub enum Error { Incompatible(String), @@ -125,10 +108,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { &mut self, p: DeclarationParameter<'ast, T>, ) -> Result, Self::Error> { - // this is only used on the entry point - let id = p.id.id.id.id.clone(); - assert!(self.ssa.versions.insert_in_frame(id, 0, 0).is_none()); - Ok(p) + Ok(self.ssa.fold_parameter(p)) } fn fold_function_call_expression< @@ -138,34 +118,42 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ty: &E::Ty, e: FunctionCallExpression<'ast, T, E>, ) -> Result, Self::Error> { + // generics are already in ssa form + let generics = e .generics .into_iter() .map(|g| { g.map(|g| { - let g = self.ssa.fold_uint_expression(g); let g = self.propagator.fold_uint_expression(g)?; + let g = self.fold_uint_expression(g)?; - self.fold_uint_expression(g) + self.propagator + .fold_uint_expression(g) + .map_err(Self::Error::from) }) .transpose() }) .collect::>()?; + // arguments are already in ssa form + let arguments = e .arguments .into_iter() .map(|e| { - let e = self.ssa.fold_expression(e); let e = self.propagator.fold_expression(e)?; + let e = self.fold_expression(e)?; - self.fold_expression(e) + self.propagator + .fold_expression(e) + .map_err(Self::Error::from) }) .collect::>()?; self.ssa.push_call_frame(); - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, &self.program); + let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); let res = match res { Ok((input_variables, arguments, generics_bindings, statements, expression)) => { @@ -183,7 +171,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.statement_buffer.extend(generics_bindings); - // the lhs is from the inner call frame, the rhs is from the outer one, so only fld the lhs + // the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs let input_bindings: Vec<_> = input_variables .into_iter() .zip(arguments) @@ -274,27 +262,6 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { }) } - fn fold_assembly_statement( - &mut self, - s: TypedAssemblyStatement<'ast, T>, - ) -> Result>, Self::Error> { - Ok(match s { - TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - self.fold_assignee(a)?, - self.fold_expression(e)?, - )] - } - TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { - vec![TypedAssemblyStatement::Constraint( - self.fold_field_expression(lhs)?, - self.fold_field_expression(rhs)?, - metadata, - )] - } - }) - } - fn fold_canonical_constant_identifier( &mut self, _: CanonicalConstantIdentifier<'ast>, @@ -307,28 +274,16 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { s: TypedStatement<'ast, T>, ) -> Result>, Self::Error> { let res = match s { - TypedStatement::Definition(a, rhs) => { - // usually we transform and then propagate - // for definitions we need special treatment: we transform and propagate the rhs (which can contain function calls) - // then we reduce the rhs to remove the function calls - // only then we transform and propagate the assignee - - let rhs = self.ssa.fold_definition_rhs(rhs); - let rhs = self.propagator.fold_definition_rhs(rhs)?; - let rhs = self.fold_definition_rhs(rhs)?; - - let a = self.ssa.fold_assignee(a); - - self.propagator - .fold_statement(TypedStatement::Definition(a, rhs))? - } TypedStatement::For(v, from, to, statements) => { let from = self.ssa.fold_uint_expression(from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from @@ -345,40 +300,37 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { .collect::, _>>()? .into_iter() .flatten() - .collect()), + .collect::>()), _ => Err(Error::NonConstant(format!( "Expected loop bounds to be constant, found {}..{}", from, to ))), }? } - TypedStatement::Return(e) => { - let e = self.ssa.fold_expression(e); - let e = self.propagator.fold_expression(e)?; - vec![TypedStatement::Return(self.fold_expression(e)?)] - } - TypedStatement::Assertion(e, error) => { - let e = self.ssa.fold_boolean_expression(e); - let e = self.propagator.fold_boolean_expression(e)?; + s => { + let statements = self.ssa.fold_statement(s); - vec![TypedStatement::Assertion( - self.fold_boolean_expression(e)?, - error, - )] + let statements = statements + .into_iter() + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| fold_statement(self, s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + let statements = statements + .map(|s| self.propagator.fold_statement(s)) + .collect::, _>>()? + .into_iter() + .flatten(); + + statements.collect() } - s => self - .ssa - .fold_statement(s) - .into_iter() - .map(|s| self.propagator.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .map(|s| fold_statement(self, s)) - .collect::, _>>()? - .into_iter() - .flatten() - .collect(), }; Ok(self.statement_buffer.drain(..).chain(res).collect()) @@ -394,12 +346,17 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let array = self.ssa.fold_array_expression(array); let array = self.propagator.fold_array_expression(array)?; let array = self.fold_array_expression(array)?; + let array = self.propagator.fold_array_expression(array)?; + let from = self.ssa.fold_uint_expression(from); let from = self.propagator.fold_uint_expression(from)?; let from = self.fold_uint_expression(from)?; + let from = self.propagator.fold_uint_expression(from)?; + let to = self.ssa.fold_uint_expression(to); let to = self.propagator.fold_uint_expression(to)?; let to = self.fold_uint_expression(to)?; + let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { (UExpressionInner::Value(..), UExpressionInner::Value(..)) => { @@ -503,14 +460,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // a_1 = a_0; - // # PUSH CALL to foo - // a_3 := a_1; // input binding - // #RETURN_AT_INDEX_0_0 := a_3; - // # POP CALL - // a_2 = #RETURN_AT_INDEX_0_0; - // return a_2; + // def main(field a_f0_v0) -> field { + // a_f0_v1 = a_f0_v0; // redef + // a_f1_v0 = a_f0_v1; // input binding + // a_f0_v2 = a_f1_v0; // output binding + // return a_f0_v2; // } let foo: TypedFunction = TypedFunction { @@ -606,30 +560,13 @@ mod tests { Variable::field_element(Identifier::from("a").version(1)).into(), FieldElementExpression::identifier("a".into()).into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo").signature( - DeclarationSignature::new() - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - ), - GGenericsAssignment::default(), - ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), + Variable::field_element(Identifier::from("a").in_frame(1)).into(), FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), - TypedStatement::definition( - Variable::field_element(Identifier::from(CoreIdentifier::Call(0)).version(0)) - .into(), - FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::field_element(Identifier::from("a").version(2)).into(), - FieldElementExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .into(), + FieldElementExpression::identifier(Identifier::from("a").in_frame(1)).into(), ), TypedStatement::Return( FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), @@ -678,14 +615,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a_f0_v0) -> field { + // field[1] b_f0_v0 = [a_f0_v0]; + // a_f1_v0 = b_f0_v0; + // b_f0_v1 = a_f1_v0; + // return a_f0_v0 + b_f0_v1[0]; // } let foo_signature = DeclarationSignature::new() @@ -812,42 +746,19 @@ mod tests { .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -902,14 +813,11 @@ mod tests { // } // expected: - // def main(field a_0) -> field { - // field[1] b_0 = [42]; - // # PUSH CALL to foo::<1> - // a_0 = b_0; - // #RETURN_AT_INDEX_0_0 := a_0; - // # POP CALL - // b_1 = #RETURN_AT_INDEX_0_0; - // return a_2 + b_1[0]; + // def main(field a) -> field { + // field[1] b = [a]; + // a_f1 = b; + // b_1 = a_f1; + // return a + b_1[0]; // } let foo_signature = DeclarationSignature::new() @@ -1040,47 +948,25 @@ mod tests { TypedStatement::definition( Variable::array("b", Type::FieldElement, 1u32).into(), ArrayExpressionInner::Value( - vec![FieldElementExpression::identifier("a".into()).into()].into(), + vec![FieldElementExpression::identifier(Identifier::from("a")).into()] + .into(), ) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), TypedStatement::definition( - Variable::array(Identifier::from("a").version(1), Type::FieldElement, 1u32) + Variable::array(Identifier::from("a").in_frame(1), Type::FieldElement, 1u32) .into(), ArrayExpression::identifier("b".into()) .annotate(Type::FieldElement, 1u32) .into(), ), - TypedStatement::definition( - Variable::array( - Identifier::from(CoreIdentifier::Call(0)).version(0), - Type::FieldElement, - 1u32, - ) - .into(), - ArrayExpression::identifier(Identifier::from("a").version(1)) - .annotate(Type::FieldElement, 1u32) - .into(), - ), - TypedStatement::PopCallLog, TypedStatement::definition( Variable::array(Identifier::from("b").version(1), Type::FieldElement, 1u32) .into(), - ArrayExpression::identifier( - Identifier::from(CoreIdentifier::Call(0)).version(0), - ) - .annotate(Type::FieldElement, 1u32) - .into(), + ArrayExpression::identifier(Identifier::from("a").in_frame(1)) + .annotate(Type::FieldElement, 1u32) + .into(), ), TypedStatement::Return( (FieldElementExpression::identifier("a".into()) @@ -1306,33 +1192,11 @@ mod tests { let expected_main = TypedFunction { arguments: vec![], - statements: vec![ - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "foo") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PushCallLog( - DeclarationFunctionKey::with_location("main", "bar") - .signature(foo_signature.clone()), - GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 2)] - .into_iter() - .collect(), - ), - ), - TypedStatement::PopCallLog, - TypedStatement::PopCallLog, - TypedStatement::Return( - TupleExpressionInner::Value(vec![]) - .annotate(TupleType::new(vec![])) - .into(), - ), - ], + statements: vec![TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + )], signature: DeclarationSignature::new(), }; diff --git a/zokrates_analysis/src/reducer/shallow_ssa.rs b/zokrates_analysis/src/reducer/shallow_ssa.rs index 0d4509506..aaec4fd56 100644 --- a/zokrates_analysis/src/reducer/shallow_ssa.rs +++ b/zokrates_analysis/src/reducer/shallow_ssa.rs @@ -1,7 +1,6 @@ -// The SSA transformation leaves gaps in the indices when it hits a for-loop, so that the body of the for-loop can -// modify the variables in scope. The state of the indices before all for-loops is returned to account for that possibility. -// Function calls are also left unvisited -// Saving the indices is not required for function calls, as they cannot modify their environment +// The SSA transformation +// * introduces new versions if and only if we are assigning to an identifier +// * does not visit the statements of loops // Example: // def main(field a) -> field { @@ -19,21 +18,34 @@ // u32 n_0 = 42; // a_1 = a_0 + 1; // field b_0 = foo(a_1); // we keep the function call as is -// # versions: {n: 0, a: 1, b: 0} // for u32 i_0 in 0..n_0 { // // we keep the loop body as is // } // return b_3; // we leave versions b_1 and b_2 to make b accessible and modifiable inside the for-loop // } +use std::collections::HashMap; + use zokrates_ast::typed::folder::*; -use zokrates_ast::typed::identifier::FrameIdentifier; use zokrates_ast::typed::*; use zokrates_field::Field; -use super::Versions; +// An SSA version map, giving access to the latest version number for each identifier +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Versions<'ast> { + map: HashMap, usize>>, +} + +impl<'ast> Default for Versions<'ast> { + fn default() -> Self { + // create a call frame at index 0 + Self { + map: vec![(0, Default::default())].into_iter().collect(), + } + } +} #[derive(Debug, Default)] pub struct ShallowTransformer<'ast> { @@ -45,14 +57,16 @@ pub struct ShallowTransformer<'ast> { impl<'ast> ShallowTransformer<'ast> { pub fn issue_next_identifier(&mut self, c_id: CoreIdentifier<'ast>) -> Identifier<'ast> { - let frame_versions = self.versions.map.entry(self.frame()).or_default(); + let frame = self.frame(); + + let frame_versions = self.versions.map.entry(frame).or_default(); let version = frame_versions .entry(c_id.clone()) .and_modify(|e| *e += 1) // if it was already declared, we increment .or_default(); // otherwise, we start from this version - Identifier::from(c_id).version(*version) + Identifier::from(c_id.in_frame(frame)).version(*version) } fn issue_next_ssa_variable(&mut self, v: Variable<'ast, T>) -> Variable<'ast, T> { @@ -81,43 +95,69 @@ impl<'ast> ShallowTransformer<'ast> { self.versions.map.remove(&frame); } - pub fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + // fold an assignee replacing by the latest version. This is necessary because the trait implementation increases the ssa version for identifiers, + // but this should not be applied recursively to complex assignees + fn fold_assignee_no_ssa_increase( + &mut self, + a: TypedAssignee<'ast, T>, + ) -> TypedAssignee<'ast, T> { match a { - TypedAssignee::Identifier(v) => { - let v = self.issue_next_ssa_variable(v); - TypedAssignee::Identifier(self.fold_variable(v)) + TypedAssignee::Identifier(v) => TypedAssignee::Identifier(self.fold_variable(v)), + TypedAssignee::Select(box a, box index) => TypedAssignee::Select( + box self.fold_assignee_no_ssa_increase(a), + box self.fold_uint_expression(index), + ), + TypedAssignee::Member(box s, m) => { + TypedAssignee::Member(box self.fold_assignee_no_ssa_increase(s), m) + } + TypedAssignee::Element(box s, index) => { + TypedAssignee::Element(box self.fold_assignee_no_ssa_increase(s), index) } - a => fold_assignee(self, a), } } } impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { - fn fold_assembly_statement( + fn fold_function(&mut self, f: TypedFunction<'ast, T>) -> TypedFunction<'ast, T> { + for g in &f.signature.generics { + let generic_parameter = match g.as_ref().unwrap() { + DeclarationConstant::Generic(g) => g, + _ => unreachable!(), + }; + let _ = self.issue_next_identifier(CoreIdentifier::from(generic_parameter.clone())); + } + + fold_function(self, f) + } + + fn fold_parameter( &mut self, - s: TypedAssemblyStatement<'ast, T>, - ) -> Vec> { - match s { - TypedAssemblyStatement::Assignment(a, e) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedAssemblyStatement::Assignment(a, e)] + p: DeclarationParameter<'ast, T>, + ) -> DeclarationParameter<'ast, T> { + DeclarationParameter { + id: DeclarationVariable { + id: self.issue_next_identifier(p.id.id.id.id), + ..p.id + }, + ..p + } + } + + fn fold_assignee(&mut self, a: TypedAssignee<'ast, T>) -> TypedAssignee<'ast, T> { + match a { + // create a new version for assignments to identifiers + TypedAssignee::Identifier(v) => { + let v = self.issue_next_ssa_variable(v); + TypedAssignee::Identifier(self.fold_variable(v)) } - s => fold_assembly_statement(self, s), + // otherwise, simply replace by the current version + a => self.fold_assignee_no_ssa_increase(a), } } + fn fold_statement(&mut self, s: TypedStatement<'ast, T>) -> Vec> { match s { - TypedStatement::Definition(a, DefinitionRhs::Expression(e)) => { - let e = self.fold_expression(e); - let a = self.fold_assignee(a); - vec![TypedStatement::definition(a, e)] - } - TypedStatement::Definition(assignee, DefinitionRhs::EmbedCall(embed_call)) => { - let embed_call = self.fold_embed_call(embed_call); - let assignee = self.fold_assignee(assignee); - vec![TypedStatement::embed_call_definition(assignee, embed_call)] - } + // only fold bounds of for loop statements TypedStatement::For(v, from, to, stats) => { let from = self.fold_uint_expression(from); let to = self.fold_uint_expression(to); @@ -127,6 +167,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { } } + // retrieve the latest version fn fold_name(&mut self, n: Identifier<'ast>) -> Identifier<'ast> { let version = self .versions @@ -137,13 +178,7 @@ impl<'ast, T: Field> Folder<'ast, T> for ShallowTransformer<'ast> { .cloned() .unwrap_or(0); - let id = FrameIdentifier { - frame: self.frame(), - ..n.id - }; - - let res = Identifier { version, id }; - res + n.in_frame(self.frame()).version(version) } } @@ -156,36 +191,57 @@ mod tests { use super::*; #[test] - fn detect_non_constant_bound() { - let loops: Vec> = vec![TypedStatement::For( - Variable::new("i", Type::Uint(UBitwidth::B32), false), - UExpression::identifier("i".into()).annotate(UBitwidth::B32), - 2u32.into(), - vec![], - )]; + fn ignore_loop_content() { + // field foo = 0 + // u32 i = 4; + // for u32 i in i..2 { + // foo = 5; + // } - let statements = loops; + // should be left unchanged, as we do not visit the loop content nor the index variable let f = TypedFunction { arguments: vec![], - signature: DeclarationSignature::new(), - statements, + statements: vec![ + TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from("foo"))), + FieldElementExpression::Number(Bn128Field::from(4)).into(), + ), + TypedStatement::definition( + TypedAssignee::Identifier(Variable::uint( + Identifier::from("i"), + UBitwidth::B32, + )), + UExpression::from(0u32).into(), + ), + TypedStatement::For( + Variable::new("i", Type::Uint(UBitwidth::B32), false), + UExpression::identifier("i".into()).annotate(UBitwidth::B32), + 2u32.into(), + vec![TypedStatement::definition( + TypedAssignee::Identifier(Variable::field_element(Identifier::from( + "foo", + ))), + FieldElementExpression::Number(Bn128Field::from(5)).into(), + )], + ), + TypedStatement::Return( + TupleExpressionInner::Value(vec![]) + .annotate(TupleType::new(vec![])) + .into(), + ), + ], + signature: DeclarationSignature::default(), }; - match ShallowTransformer::transform( - f, - &ConcreteGenericsAssignment::default(), - &mut Versions::default(), - ) { - Output::Incomplete(..) => {} - _ => unreachable!(), - }; + let mut ssa = ShallowTransformer::default(); + + assert_eq!(ssa.fold_function(f.clone()), f); } #[test] fn definition() { - // field a - // a = 5 + // field a = 5 // a = 6 // a @@ -194,9 +250,7 @@ mod tests { // a_1 = 6 // a_1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -236,17 +290,14 @@ mod tests { #[test] fn incremental_definition() { - // field a - // a = 5 + // field a = 5 // a = a + 1 // should be turned into // a_0 = 5 // a_1 = a_0 + 1 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -295,9 +346,7 @@ mod tests { // a_0 = 2 // a_1 = foo(a_0) - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::field_element("a")), @@ -356,9 +405,7 @@ mod tests { // a_0 = [1, 1] // a_0[1] = 2 - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let s = TypedStatement::definition( TypedAssignee::Identifier(Variable::array("a", Type::FieldElement, 2u32)), @@ -413,9 +460,7 @@ mod tests { // a_0 = [[0, 1], [2, 3]] // a_0 = [4, 5] - let mut versions = Versions::new(); - - let mut u = ShallowTransformer::with_versions(&mut versions); + let mut u = ShallowTransformer::default(); let array_of_array_ty = Type::array((Type::array((Type::FieldElement, 2u32)), 2u32)); @@ -510,10 +555,10 @@ mod tests { mod for_loop { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; + #[test] fn treat_loop() { - // def main(field a) -> field { + // def main(field a) -> field { // u32 n = 42; // n = n; // a = a; @@ -528,24 +573,21 @@ mod tests { // return a; // } - // When called with K := 1, expected: + // expected: // def main(field a_0) -> field { - // u32 K = 1; // u32 n_0 = 42; // n_1 = n_0; // a_1 = a_0; - // # versions: {n: 1, a: 1, K: 0} // for u32 i_0 in n_1..n_1*n_1 { // a_0 = a_0; // } - // a_3 = a_2; - // # versions: {n: 2, a: 3, K: 1} - // for u32 i_0 in n_2..n_2*n_2 { + // a_2 = a_1; + // for u32 i_0 in n_1..n_1*n_1 { // a_0 = a_0; // } - // a_5 = a_4; - // return a_5; - // } # versions: {n: 3, a: 5, K: 2} + // a_3 = a_2; + // return a_3; + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -595,32 +637,15 @@ mod tests { TypedStatement::Return(FieldElementExpression::identifier("a".into()).into()), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); + let mut ssa = ShallowTransformer::default(); let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -649,16 +674,16 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(3)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), + Variable::field_element(Identifier::from("a").version(2)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(1)).into(), ), TypedStatement::For( Variable::uint("i", UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), - UExpression::identifier(Identifier::from("n").version(2)) + UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32) - * UExpression::identifier(Identifier::from("n").version(2)) + * UExpression::identifier(Identifier::from("n").version(1)) .annotate(UBitwidth::B32), vec![TypedStatement::definition( Variable::field_element("a").into(), @@ -666,47 +691,35 @@ mod tests { )], ), TypedStatement::definition( - Variable::field_element(Identifier::from("a").version(5)).into(), - FieldElementExpression::identifier(Identifier::from("a").version(4)).into(), + Variable::field_element(Identifier::from("a").version(3)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(2)).into(), ), TypedStatement::Return( - FieldElementExpression::identifier(Identifier::from("a").version(5)).into(), + FieldElementExpression::identifier(Identifier::from("a").version(3)).into(), ), ], signature: DeclarationSignature::new() - .generics(vec![Some( - GenericIdentifier::with_name("K").with_index(0).into(), - )]) .inputs(vec![DeclarationType::FieldElement]) .output(DeclarationType::FieldElement), }; - assert_eq!( - versions, - vec![("n".into(), 3), ("a".into(), 5), ("K".into(), 2)] - .into_iter() - .collect::() - ); + let res = ssa.fold_function(f); - let expected = Output::Incomplete( - expected, - vec![ - vec![("n".into(), 1), ("a".into(), 1), ("K".into(), 0)] - .into_iter() - .collect::(), - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 1)] - .into_iter() - .collect::(), - ], + assert_eq!( + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 1), ("a".into(), 3)].into_iter().collect() + )] + .into_iter() + .collect() ); - assert_eq!(ssa, expected); + assert_eq!(res, expected); } } mod shadowing { - use zokrates_ast::typed::types::GGenericsAssignment; - use super::*; #[test] @@ -717,11 +730,11 @@ mod tests { // return; // } - // should become + // should become (only the field variable is affected as shadowing is taken care of in semantics already) - // def main(field a_0) { - // field a_1 = 42; - // bool a_2 = true; + // def main(field a_s0_v0) { + // field a_s0_v1 = 42; + // bool a_s1_v0 = true // return; // } @@ -733,7 +746,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean("a").into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -742,9 +759,7 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; let expected: TypedFunction = TypedFunction { @@ -755,7 +770,11 @@ mod tests { TypedExpression::Uint(42u32.into()), ), TypedStatement::definition( - Variable::boolean(Identifier::from("a").version(2)).into(), + Variable::boolean(CoreIdentifier::from(ShadowedIdentifier::shadow( + "a".into(), + 1, + ))) + .into(), BooleanExpression::Value(true).into(), ), TypedStatement::Return( @@ -764,121 +783,17 @@ mod tests { .into(), ), ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]), - }; - - let mut versions = Versions::default(); - - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); - - assert_eq!(ssa, Output::Complete(expected)); - } - - #[test] - fn next_scope() { - // def main(field a) { - // for u32 i in 0..1 { - // a = a + 1 - // field a = 42 - // } - // return a - // } - - // should become - - // def main(field a_0) { - // # versions: {a: 0} - // for u32 i in 0..1 { - // a_0 = a_0 - // field a_0 = 42 - // } - // return a_1 - // } - - let f: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier("a".into()).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - "a".into(), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), + signature: DeclarationSignature::new().inputs(vec![DeclarationType::FieldElement]), }; - let expected: TypedFunction = TypedFunction { - arguments: vec![DeclarationVariable::field_element("a").into()], - statements: vec![ - TypedStatement::For( - Variable::uint("i", UBitwidth::B32), - 0u32.into(), - 1u32.into(), - vec![ - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::identifier(Identifier::from("a")).into(), - ), - TypedStatement::definition( - Variable::field_element(Identifier::from("a")).into(), - FieldElementExpression::Number(42usize.into()).into(), - ), - ], - ), - TypedStatement::Return( - TupleExpressionInner::Value(vec![FieldElementExpression::identifier( - Identifier::from("a").version(1), - ) - .into()]) - .annotate(TupleType::new(vec![Type::FieldElement])) - .into(), - ), - ], - signature: DeclarationSignature::new() - .generics(vec![]) - .inputs(vec![DeclarationType::FieldElement]) - .output(DeclarationType::FieldElement), - }; - - let mut versions = Versions::default(); - - let ssa = - ShallowTransformer::transform(f, &GGenericsAssignment::default(), &mut versions); + let ssa = ShallowTransformer::default().fold_function(f); - assert_eq!( - ssa, - Output::Incomplete(expected, vec![vec![("a".into(), 0)].into_iter().collect()]) - ); + assert_eq!(ssa, expected); } } mod function_call { use super::*; - use zokrates_ast::typed::types::GGenericsAssignment; // test that function calls are left in #[test] fn treat_calls() { @@ -892,17 +807,12 @@ mod tests { // return a; // } - // When called with K := 1, expected: // def main(field a_0) -> field { - // K = 1; - // u32 n_0 = 42; - // n_1 = n_0; // a_1 = a_0; - // a_2 = foo::(a_1); - // n_2 = n_1; - // a_3 = a_2 * foo::(a_2); + // a_2 = foo::<42>(a_1); + // a_3 = a_2 * foo::<42>(a_2); // return a_3; - // } # versions: {n: 2, a: 3} + // } let f: TypedFunction = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], @@ -960,25 +870,9 @@ mod tests { .output(DeclarationType::FieldElement), }; - let mut versions = Versions::default(); - - let ssa = ShallowTransformer::transform( - f, - &GGenericsAssignment( - vec![(GenericIdentifier::with_name("K").with_index(0), 1)] - .into_iter() - .collect(), - ), - &mut versions, - ); - let expected = TypedFunction { arguments: vec![DeclarationVariable::field_element("a").into()], statements: vec![ - TypedStatement::definition( - Variable::uint("K", UBitwidth::B32).into(), - TypedExpression::Uint(1u32.into()), - ), TypedStatement::definition( Variable::uint("n", UBitwidth::B32).into(), TypedExpression::Uint(42u32.into()), @@ -1042,14 +936,23 @@ mod tests { .output(DeclarationType::FieldElement), }; + let mut ssa = ShallowTransformer::default(); + + let res = ssa.fold_function(f); + assert_eq!( - versions, - vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] - .into_iter() - .collect::() + ssa.versions.map, + vec![( + 0, + vec![("n".into(), 2), ("a".into(), 3), ("K".into(), 0)] + .into_iter() + .collect() + )] + .into_iter() + .collect() ); - assert_eq!(ssa, Output::Incomplete(expected, vec![],)); + assert_eq!(res, expected); } } } diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 1180874fd..989819fcf 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -531,10 +531,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a), - f.fold_expression(e), - )] + let e = f.fold_expression(e); + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a), e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -552,8 +550,9 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( ) -> Vec> { let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)), - TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a), f.fold_definition_rhs(e)) + TypedStatement::Definition(a, rhs) => { + let rhs = f.fold_definition_rhs(rhs); + TypedStatement::Definition(f.fold_assignee(a), rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e), error) @@ -576,7 +575,6 @@ pub fn fold_statement<'ast, T: Field, F: Folder<'ast, T>>( .flat_map(|s| f.fold_assembly_statement(s)) .collect(), ), - s => s, }; vec![res] } diff --git a/zokrates_ast/src/typed/identifier.rs b/zokrates_ast/src/typed/identifier.rs index 91aaaa308..d23a9b3b7 100644 --- a/zokrates_ast/src/typed/identifier.rs +++ b/zokrates_ast/src/typed/identifier.rs @@ -24,6 +24,21 @@ impl<'ast> fmt::Display for CoreIdentifier<'ast> { } } +impl<'ast> FrameIdentifier<'ast> { + pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { + FrameIdentifier { frame, ..self } + } +} + +impl<'ast> Identifier<'ast> { + pub fn in_frame(self, frame: usize) -> Identifier<'ast> { + Identifier { + id: self.id.in_frame(frame), + ..self + } + } +} + impl<'ast> CoreIdentifier<'ast> { pub fn in_frame(self, frame: usize) -> FrameIdentifier<'ast> { FrameIdentifier { id: self, frame } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index df8b0f638..1d4ad0fcf 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -27,7 +27,7 @@ pub use self::types::{ UBitwidth, }; use self::types::{ConcreteArrayType, ConcreteStructType}; -use crate::typed::types::{ConcreteGenericsAssignment, IntoType}; +use crate::typed::types::IntoType; pub use self::variable::{ConcreteVariable, DeclarationVariable, GVariable, Variable}; use std::marker::PhantomData; @@ -353,19 +353,8 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedFunction<'ast, T> { writeln!(f)?; - let mut tab = 0; - for s in &self.statements { - if let TypedStatement::PopCallLog = s { - tab -= 1; - }; - - s.fmt_indented(f, 1 + tab)?; - writeln!(f)?; - - if let TypedStatement::PushCallLog(..) = s { - tab += 1; - }; + writeln!(f, "{}", s)?; } writeln!(f, "}}")?; @@ -695,12 +684,6 @@ pub enum TypedStatement<'ast, T> { Vec>, ), Log(FormatString, Vec>), - // Aux - PushCallLog( - DeclarationFunctionKey<'ast, T>, - ConcreteGenericsAssignment<'ast>, - ), - PopCallLog, Assembly(Vec>), } @@ -714,31 +697,6 @@ impl<'ast, T> TypedStatement<'ast, T> { } } -impl<'ast, T: fmt::Display> TypedStatement<'ast, T> { - fn fmt_indented(&self, f: &mut fmt::Formatter, depth: usize) -> fmt::Result { - match self { - TypedStatement::For(variable, from, to, statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "for {} in {}..{} {{", variable, from, to)?; - for s in statements { - s.fmt_indented(f, depth + 1)?; - writeln!(f)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - TypedStatement::Assembly(statements) => { - write!(f, "{}", "\t".repeat(depth))?; - writeln!(f, "asm {{")?; - for s in statements { - writeln!(f, "{}{}", "\t".repeat(depth + 1), s)?; - } - write!(f, "{}}}", "\t".repeat(depth)) - } - s => write!(f, "{}{}", "\t".repeat(depth), s), - } - } -} - impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { @@ -773,14 +731,6 @@ impl<'ast, T: fmt::Display> fmt::Display for TypedStatement<'ast, T> { .collect::>() .join(", ") ), - TypedStatement::PushCallLog(ref key, ref generics) => write!( - f, - "// PUSH CALL TO {}/{}::<{}>", - key.module.display(), - key.id, - generics, - ), - TypedStatement::PopCallLog => write!(f, "// POP CALL",), TypedStatement::Assembly(ref statements) => { writeln!(f, "asm {{")?; for s in statements { diff --git a/zokrates_ast/src/typed/result_folder.rs b/zokrates_ast/src/typed/result_folder.rs index 25c84c292..8ed911314 100644 --- a/zokrates_ast/src/typed/result_folder.rs +++ b/zokrates_ast/src/typed/result_folder.rs @@ -532,10 +532,8 @@ pub fn fold_assembly_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( ) -> Result>, F::Error> { Ok(match s { TypedAssemblyStatement::Assignment(a, e) => { - vec![TypedAssemblyStatement::Assignment( - f.fold_assignee(a)?, - f.fold_expression(e)?, - )] + let e = f.fold_expression(e)?; + vec![TypedAssemblyStatement::Assignment(f.fold_assignee(a)?, e)] } TypedAssemblyStatement::Constraint(lhs, rhs, metadata) => { vec![TypedAssemblyStatement::Constraint( @@ -554,7 +552,8 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( let res = match s { TypedStatement::Return(e) => TypedStatement::Return(f.fold_expression(e)?), TypedStatement::Definition(a, e) => { - TypedStatement::Definition(f.fold_assignee(a)?, f.fold_definition_rhs(e)?) + let rhs = f.fold_definition_rhs(e)?; + TypedStatement::Definition(f.fold_assignee(a)?, rhs) } TypedStatement::Assertion(e, error) => { TypedStatement::Assertion(f.fold_boolean_expression(e)?, error) @@ -586,7 +585,6 @@ pub fn fold_statement<'ast, T: Field, F: ResultFolder<'ast, T>>( .flatten() .collect(), ), - s => s, }; Ok(vec![res]) } diff --git a/zokrates_ast/src/typed/types.rs b/zokrates_ast/src/typed/types.rs index 60d3792fc..f2bf23d5b 100644 --- a/zokrates_ast/src/typed/types.rs +++ b/zokrates_ast/src/typed/types.rs @@ -240,9 +240,9 @@ impl<'ast, T> From for UExpression<'ast, T> { impl<'ast, T: Field> From> for UExpression<'ast, T> { fn from(c: DeclarationConstant<'ast, T>) -> Self { match c { - DeclarationConstant::Generic(_g) => { - // UExpression::identifier(FrameIdentifier::from(g).into()).annotate(UBitwidth::B32) - unreachable!() + DeclarationConstant::Generic(g) => { + UExpression::identifier(Identifier::from(CoreIdentifier::from(g))) + .annotate(UBitwidth::B32) } DeclarationConstant::Concrete(v) => { UExpressionInner::Value(v as u128).annotate(UBitwidth::B32) diff --git a/zokrates_core_test/tests/tests/call_ssa.json b/zokrates_core_test/tests/tests/call_ssa.json new file mode 100644 index 000000000..436754828 --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.json @@ -0,0 +1,16 @@ +{ + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "4" + } + } + } + ] + } + \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/call_ssa.zok b/zokrates_core_test/tests/tests/call_ssa.zok new file mode 100644 index 000000000..dad61d190 --- /dev/null +++ b/zokrates_core_test/tests/tests/call_ssa.zok @@ -0,0 +1,11 @@ +// main should be x -> x + 4 + +def foo(field mut a) -> field { + a = a + 1; + return a + 1; +} + +def main(field mut a) -> field { + a = foo(a + 1); + return a + 1; +} \ No newline at end of file From c01cc25c5133d339df5109fb8cbf448d4d945d3d Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 22 Feb 2023 21:35:21 +0100 Subject: [PATCH 11/19] changelog, prettier --- changelogs/unreleased/1283-schaeff | 1 + zokrates_core_test/tests/tests/call_ssa.json | 25 ++++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) create mode 100644 changelogs/unreleased/1283-schaeff diff --git a/changelogs/unreleased/1283-schaeff b/changelogs/unreleased/1283-schaeff new file mode 100644 index 000000000..f4cf39095 --- /dev/null +++ b/changelogs/unreleased/1283-schaeff @@ -0,0 +1 @@ +Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) \ No newline at end of file diff --git a/zokrates_core_test/tests/tests/call_ssa.json b/zokrates_core_test/tests/tests/call_ssa.json index 436754828..ae021f62d 100644 --- a/zokrates_core_test/tests/tests/call_ssa.json +++ b/zokrates_core_test/tests/tests/call_ssa.json @@ -1,16 +1,15 @@ { - "max_constraint_count": 1, - "tests": [ - { - "input": { - "values": ["0"] - }, - "output": { - "Ok": { - "value": "4" - } + "max_constraint_count": 1, + "tests": [ + { + "input": { + "values": ["0"] + }, + "output": { + "Ok": { + "value": "4" } } - ] - } - \ No newline at end of file + } + ] +} From 156ff243063b896b69f1337fa56fa2bebe562e63 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 22 Feb 2023 21:43:56 +0100 Subject: [PATCH 12/19] clean inliner --- zokrates_analysis/src/reducer/inline.rs | 40 ++++++------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index 002228f7d..edaa28032 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -15,16 +15,17 @@ // ``` // // Becomes -// ``` -// # Call foo::<42> with a_0 := x -// n_0 = 42 -// a_1 = a_0 -// n_1 = n_0 -// # Pop call with #CALL_RETURN_AT_INDEX_0_0 := a_1 +// inputs: [a] +// arguments: [x] +// generics_bindings: [n = 42] +// statements: +// n = 42 +// a = a +// n = n +// return_expression: a // Notes: -// - The body of the function is in SSA form -// - The return value(s) are assigned to internal variables +// - The body of the function is *not* in SSA form use zokrates_ast::common::FlatEmbed; use zokrates_ast::typed::types::{ConcreteGenericsAssignment, IntoType}; @@ -165,12 +166,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); - // let ssa_f = ShallowTransformer::transform(f, &assignment, versions); - - // let ssa_f = f; - - // let call_log = TypedStatement::PushCallLog(decl.key.clone(), assignment.clone()); - let generics_bindings: Vec<_> = assignment .0 .into_iter() @@ -205,23 +200,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( _ => unreachable!(), }; - // let v: ConcreteVariable<'ast> = ConcreteVariable::new( - // Identifier::from(CoreIdentifier::Call(0)).version( - // *versions - // .entry(CoreIdentifier::Call(0)) - // .and_modify(|e| *e += 1) // if it was already declared, we increment - // .or_insert(0), - // ), - // *inferred_signature.output.clone(), - // false, - // ); - - // let expression = TypedExpression::from(Variable::from(v.clone())); - - // let output_binding = TypedStatement::definition(Variable::from(v).into(), return_expression); - - // let pop_log = TypedStatement::PopCallLog; - Ok(( input_variables, arguments, From ba7aa6044c1e460eeb0650d2f5aaf6d3b01874b3 Mon Sep 17 00:00:00 2001 From: schaeff Date: Wed, 22 Feb 2023 21:51:58 +0100 Subject: [PATCH 13/19] clean --- zokrates_ast/src/typed/folder.rs | 2 +- zokrates_ast/src/typed/mod.rs | 42 ------------------- zokrates_core/src/imports.rs | 1 - .../tests/tests/uint/rotate.zok | 3 +- 4 files changed, 2 insertions(+), 46 deletions(-) diff --git a/zokrates_ast/src/typed/folder.rs b/zokrates_ast/src/typed/folder.rs index 989819fcf..722dcbf87 100644 --- a/zokrates_ast/src/typed/folder.rs +++ b/zokrates_ast/src/typed/folder.rs @@ -135,7 +135,7 @@ pub trait Folder<'ast, T: Field>: Sized { id: CoreIdentifier::Constant(self.fold_canonical_constant_identifier(c)), frame: 0, }, - _id => n.id, + _ => n.id, }; Identifier { id, ..n } diff --git a/zokrates_ast/src/typed/mod.rs b/zokrates_ast/src/typed/mod.rs index 1d4ad0fcf..bd000d12a 100644 --- a/zokrates_ast/src/typed/mod.rs +++ b/zokrates_ast/src/typed/mod.rs @@ -1253,48 +1253,6 @@ impl<'ast, T: Field> From> for FieldElementExpression<'as } } -impl<'ast, T: Field> From> for BooleanExpression<'ast, T> { - fn from(assignee: TypedAssignee<'ast, T>) -> Self { - match assignee { - TypedAssignee::Identifier(v) => BooleanExpression::identifier(v.id), - TypedAssignee::Element(box a, index) => BooleanExpression::element(a.into(), index), - TypedAssignee::Member(box a, id) => BooleanExpression::member(a.into(), id), - TypedAssignee::Select(box a, box index) => BooleanExpression::select(a.into(), index), - } - } -} - -impl<'ast, T: Field> From> for UExpression<'ast, T> { - fn from(assignee: TypedAssignee<'ast, T>) -> Self { - match assignee { - TypedAssignee::Identifier(v) => { - let inner = UExpression::identifier(v.id); - match v._type { - GType::Uint(bitwidth) => inner.annotate(bitwidth), - _ => unreachable!(), - } - } - TypedAssignee::Element(box a, index) => UExpression::element(a.into(), index), - TypedAssignee::Member(box a, id) => UExpression::member(a.into(), id), - TypedAssignee::Select(box a, box index) => UExpression::select(a.into(), index), - } - } -} - -impl<'ast, T: Field> From> for TypedExpression<'ast, T> { - fn from(assignee: TypedAssignee<'ast, T>) -> Self { - match assignee.get_type() { - Type::FieldElement => FieldElementExpression::from(assignee).into(), - Type::Boolean => BooleanExpression::from(assignee).into(), - Type::Struct(_) => StructExpression::from(assignee).into(), - Type::Array(_) => ArrayExpression::from(assignee).into(), - Type::Uint(_) => UExpression::from(assignee).into(), - Type::Tuple(_) => TupleExpression::from(assignee).into(), - Type::Int => unreachable!(), - } - } -} - impl<'ast, T> Add for FieldElementExpression<'ast, T> { type Output = Self; diff --git a/zokrates_core/src/imports.rs b/zokrates_core/src/imports.rs index d384aa646..f7fa95f13 100644 --- a/zokrates_core/src/imports.rs +++ b/zokrates_core/src/imports.rs @@ -79,7 +79,6 @@ impl Importer { .into_iter() .map(|s| match s.value.symbol { Symbol::Here(SymbolDefinition::Import(import)) => { - log::debug!("Resolve {} from {}", import, location.display()); Importer::resolve::(import, &location, resolver, modules, arena) } _ => Ok(s), diff --git a/zokrates_core_test/tests/tests/uint/rotate.zok b/zokrates_core_test/tests/tests/uint/rotate.zok index c593ae18e..c2ba820f9 100644 --- a/zokrates_core_test/tests/tests/uint/rotate.zok +++ b/zokrates_core_test/tests/tests/uint/rotate.zok @@ -3,8 +3,7 @@ import "utils/casts/u32_from_bits" as from_bits; def right_rotate(u32 e) -> u32 { bool[32] b = to_bits(e); - u32 res = from_bits([...b[32-N..], ...b[..32-N]]); - return res; + return from_bits([...b[32-N..], ...b[..32-N]]); } def main(u32 e) -> u32 { From 35c1a9686b9339c7bea09eef77e854231a5ce09f Mon Sep 17 00:00:00 2001 From: schaeff Date: Thu, 23 Feb 2023 11:26:57 +0100 Subject: [PATCH 14/19] error out if loops are too large --- zokrates_analysis/src/reducer/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index b2d5da5c6..79749d686 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -286,6 +286,11 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { let to = self.propagator.fold_uint_expression(to)?; match (from.as_inner(), to.as_inner()) { + (UExpressionInner::Value(from), UExpressionInner::Value(to)) + if to - from > MAX_FOR_LOOP_SIZE => + { + Err(Error::LoopTooLarge(to.saturating_sub(*from))) + } (UExpressionInner::Value(from), UExpressionInner::Value(to)) => Ok((*from ..*to) .flat_map(|index| { From c81cc66fd3322eca63ccb130e554667caf24793e Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 27 Feb 2023 13:35:50 +0100 Subject: [PATCH 15/19] simplify inliner --- zokrates_analysis/src/reducer/inline.rs | 101 ++++++++---------------- zokrates_analysis/src/reducer/mod.rs | 47 +++++------ 2 files changed, 53 insertions(+), 95 deletions(-) diff --git a/zokrates_analysis/src/reducer/inline.rs b/zokrates_analysis/src/reducer/inline.rs index edaa28032..8d3727d3d 100644 --- a/zokrates_analysis/src/reducer/inline.rs +++ b/zokrates_analysis/src/reducer/inline.rs @@ -42,18 +42,8 @@ use zokrates_field::Field; pub enum InlineError<'ast, T> { Generic(DeclarationFunctionKey<'ast, T>, ConcreteFunctionKey<'ast>), - Flat( - FlatEmbed, - Vec, - Vec>, - Type<'ast, T>, - ), - NonConstant( - DeclarationFunctionKey<'ast, T>, - Vec>>, - Vec>, - Type<'ast, T>, - ), + Flat(FlatEmbed, Vec, Type<'ast, T>), + NonConstant, } fn get_canonical_function<'ast, T: Field>( @@ -74,26 +64,26 @@ fn get_canonical_function<'ast, T: Field>( } } -type InlineResult<'ast, T> = Result< - ( - Vec>, - Vec>, - Vec>, - Vec>, - TypedExpression<'ast, T>, - ), - InlineError<'ast, T>, ->; +pub struct InlineValue<'ast, T> { + /// the pre-SSA input variables to assign the arguments to + pub input_variables: Vec>, + /// the pre-SSA statements for this call, including definition of the generic parameters + pub statements: Vec>, + /// the pre-SSA return value for this call + pub return_value: TypedExpression<'ast, T>, +} + +type InlineResult<'ast, T> = Result, InlineError<'ast, T>>; pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( k: &DeclarationFunctionKey<'ast, T>, - generics: Vec>>, - arguments: Vec>, - output: &E::Ty, + generics: &[Option>], + arguments: &[TypedExpression<'ast, T>], + output_ty: &E::Ty, program: &TypedProgram<'ast, T>, ) -> InlineResult<'ast, T> { use zokrates_ast::typed::Typed; - let output_type = output.clone().into_type(); + let output_type = output_ty.clone().into_type(); // we try to get concrete values for explicit generics let generics_values: Vec> = generics @@ -107,32 +97,19 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .transpose() }) .collect::>() - .map_err(|_| { - InlineError::NonConstant( - k.clone(), - generics.clone(), - arguments.clone(), - output_type.clone(), - ) - })?; + .map_err(|_| InlineError::NonConstant)?; // we infer a signature based on inputs and outputs - // this is where we could handle explicit annotations let inferred_signature = Signature::new() - .generics(generics.clone()) + .generics(generics.to_vec().clone()) .inputs(arguments.iter().map(|a| a.get_type()).collect()) .output(output_type.clone()); - // we try to get concrete values for the whole signature. if this fails we should propagate again + // we try to get concrete values for the whole signature let inferred_signature = match ConcreteSignature::try_from(inferred_signature) { Ok(s) => s, Err(_) => { - return Err(InlineError::NonConstant( - k.clone(), - generics, - arguments, - output_type, - )); + return Err(InlineError::NonConstant); } }; @@ -158,7 +135,6 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( TypedFunctionSymbol::Flat(e) => Err(InlineError::Flat( e, e.generics::(&assignment), - arguments.clone(), output_type, )), _ => unreachable!(), @@ -166,19 +142,15 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( assert_eq!(f.arguments.len(), arguments.len()); - let generics_bindings: Vec<_> = assignment - .0 - .into_iter() - .map(|(identifier, value)| { - TypedStatement::Definition( - TypedAssignee::Identifier(Variable::uint( - CoreIdentifier::from(identifier), - UBitwidth::B32, - )), - TypedExpression::from(UExpression::from(value)).into(), - ) - }) - .collect(); + let generic_bindings = assignment.0.into_iter().map(|(identifier, value)| { + TypedStatement::Definition( + TypedAssignee::Identifier(Variable::uint( + CoreIdentifier::from(identifier), + UBitwidth::B32, + )), + TypedExpression::from(UExpression::from(value)).into(), + ) + }); let input_variables: Vec> = f .arguments @@ -188,23 +160,20 @@ pub fn inline_call<'a, 'ast, T: Field, E: Expr<'ast, T>>( .map(Variable::from) .collect(); - let (statements, mut returns): (Vec<_>, Vec<_>) = f - .statements - .into_iter() + let (statements, mut returns): (Vec<_>, Vec<_>) = generic_bindings + .chain(f.statements) .partition(|s| !matches!(s, TypedStatement::Return(..))); assert_eq!(returns.len(), 1); - let return_expression = match returns.pop().unwrap() { + let return_value = match returns.pop().unwrap() { TypedStatement::Return(e) => e, _ => unreachable!(), }; - Ok(( + Ok(InlineValue { input_variables, - arguments, - generics_bindings, statements, - return_expression, - )) + return_value, + }) } diff --git a/zokrates_analysis/src/reducer/mod.rs b/zokrates_analysis/src/reducer/mod.rs index 79749d686..826cd6663 100644 --- a/zokrates_analysis/src/reducer/mod.rs +++ b/zokrates_analysis/src/reducer/mod.rs @@ -17,6 +17,7 @@ mod constants_writer; mod inline; mod shallow_ssa; +use self::inline::InlineValue; use self::inline::{inline_call, InlineError}; use std::collections::HashMap; use zokrates_ast::typed::result_folder::*; @@ -120,7 +121,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { ) -> Result, Self::Error> { // generics are already in ssa form - let generics = e + let generics: Vec<_> = e .generics .into_iter() .map(|g| { @@ -138,7 +139,7 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { // arguments are already in ssa form - let arguments = e + let arguments: Vec<_> = e .arguments .into_iter() .map(|e| { @@ -153,24 +154,14 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.ssa.push_call_frame(); - let res = inline_call::<_, E>(&e.function_key, generics, arguments, ty, self.program); + let res = inline_call::<_, E>(&e.function_key, &generics, &arguments, ty, self.program); let res = match res { - Ok((input_variables, arguments, generics_bindings, statements, expression)) => { - let generics_bindings = generics_bindings - .into_iter() - .flat_map(|s| self.ssa.fold_statement(s)) - .map(|s| self.propagator.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten() - .map(|s| self.fold_statement(s)) - .collect::, _>>()? - .into_iter() - .flatten(); - - self.statement_buffer.extend(generics_bindings); - + Ok(InlineValue { + input_variables, + statements, + return_value, + }) => { // the lhs is from the inner call frame, the rhs is from the outer one, so only fold the lhs let input_bindings: Vec<_> = input_variables .into_iter() @@ -196,27 +187,25 @@ impl<'ast, 'a, T: Field> ResultFolder<'ast, T> for Reducer<'ast, 'a, T> { self.statement_buffer.extend(statements); - let expression = self.ssa.fold_expression(expression); + let return_value = self.ssa.fold_expression(return_value); - let expression = self.propagator.fold_expression(expression)?; + let return_value = self.propagator.fold_expression(return_value)?; - let expression = self.fold_expression(expression)?; + let return_value = self.fold_expression(return_value)?; Ok(FunctionCallOrExpression::Expression( - E::from(expression).into_inner(), + E::from(return_value).into_inner(), )) } Err(InlineError::Generic(decl, conc)) => Err(Error::Incompatible(format!( "Call site `{}` incompatible with declaration `{}`", conc, decl ))), - Err(InlineError::NonConstant(key, generics, arguments, _)) => { - Err(Error::NonConstant(format!( - "Generic parameters must be compile-time constants, found {}", - FunctionCallExpression::<_, E>::new(key, generics, arguments) - ))) - } - Err(InlineError::Flat(embed, generics, arguments, output_type)) => { + Err(InlineError::NonConstant) => Err(Error::NonConstant(format!( + "Generic parameters must be compile-time constants, found {}", + FunctionCallExpression::<_, E>::new(e.function_key, generics, arguments) + ))), + Err(InlineError::Flat(embed, generics, output_type)) => { let identifier = self.ssa.issue_next_identifier(CoreIdentifier::Call(0)); let var = Variable::immutable(identifier.clone(), output_type); From 8ccf681f194b8ed231904a34dd2bd2711a1e169b Mon Sep 17 00:00:00 2001 From: schaeff Date: Mon, 27 Feb 2023 13:47:06 +0100 Subject: [PATCH 16/19] upgrade prettier action to 4.3 --- .github/workflows/js-format-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/js-format-check.yml b/.github/workflows/js-format-check.yml index f09fe5974..9c88c5cc9 100644 --- a/.github/workflows/js-format-check.yml +++ b/.github/workflows/js-format-check.yml @@ -6,6 +6,6 @@ jobs: steps: - uses: actions/checkout@v2 - name: Check format with prettier - uses: creyD/prettier_action@v4.2 + uses: creyD/prettier_action@v4.3 with: prettier_options: --check ./**/*.{js,ts,json} From 2a2370efbaee6078cdbd3bc87d6f41e5700b7567 Mon Sep 17 00:00:00 2001 From: dark64 Date: Tue, 28 Feb 2023 10:56:18 +0100 Subject: [PATCH 17/19] add changelog --- changelogs/unreleased/1280-dark64 | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelogs/unreleased/1280-dark64 diff --git a/changelogs/unreleased/1280-dark64 b/changelogs/unreleased/1280-dark64 new file mode 100644 index 000000000..e6392c713 --- /dev/null +++ b/changelogs/unreleased/1280-dark64 @@ -0,0 +1 @@ +Fix `radix-path` help message on `mpc init` subcommand \ No newline at end of file From 60ee62d3097e86704f01b41fc12c0910d07f9dc1 Mon Sep 17 00:00:00 2001 From: Ahmed Castro Date: Thu, 16 Mar 2023 19:01:46 -0600 Subject: [PATCH 18/19] Fixed small typo Found a small typo while testing this. Cheers! --- zokrates_cli/examples/sudoku/sudoku_checker.zok | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zokrates_cli/examples/sudoku/sudoku_checker.zok b/zokrates_cli/examples/sudoku/sudoku_checker.zok index 795e81ed0..d6b596f35 100644 --- a/zokrates_cli/examples/sudoku/sudoku_checker.zok +++ b/zokrates_cli/examples/sudoku/sudoku_checker.zok @@ -15,7 +15,7 @@ def countDuplicates(field e11, field e12, field e21, field e22) -> field { duplicates = duplicates + e11 == e21 ? 1 : 0; duplicates = duplicates + e11 == e22 ? 1 : 0; duplicates = duplicates + e12 == e21 ? 1 : 0; - duplicates = duplicates + e12 == e21 ? 1 : 0; + duplicates = duplicates + e12 == e22 ? 1 : 0; duplicates = duplicates + e21 == e22 ? 1 : 0; return duplicates; } From 16ef50fa9c51c1558d4dcd59e064b84b2d70e443 Mon Sep 17 00:00:00 2001 From: schaeff Date: Tue, 28 Mar 2023 10:45:53 +0200 Subject: [PATCH 19/19] bump versions, generate changelog --- CHANGELOG.md | 11 +++++++++++ Cargo.lock | 12 ++++++------ changelogs/unreleased/1275-dark64 | 1 - changelogs/unreleased/1277-dark64 | 1 - changelogs/unreleased/1280-dark64 | 1 - changelogs/unreleased/1283-schaeff | 1 - zokrates_analysis/Cargo.toml | 2 +- zokrates_ast/Cargo.toml | 2 +- zokrates_cli/Cargo.toml | 2 +- zokrates_core/Cargo.toml | 2 +- zokrates_interpreter/Cargo.toml | 2 +- zokrates_js/Cargo.toml | 2 +- zokrates_js/package.json | 2 +- 13 files changed, 24 insertions(+), 17 deletions(-) delete mode 100644 changelogs/unreleased/1275-dark64 delete mode 100644 changelogs/unreleased/1277-dark64 delete mode 100644 changelogs/unreleased/1280-dark64 delete mode 100644 changelogs/unreleased/1283-schaeff diff --git a/CHANGELOG.md b/CHANGELOG.md index 4188c1c34..9ea84c10a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file. ## [Unreleased] https://github.com/Zokrates/ZoKrates/compare/latest...develop +## [0.8.5] - 2023-03-28 + +### Release +- https://github.com/Zokrates/ZoKrates/releases/tag/0.8.5 + +### Changes +- Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) (#1283, @schaeff) +- Fix `radix-path` help message on `mpc init` subcommand (#1280, @dark64) +- Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair (#1277, @dark64) +- Show help when running `zokrates mpc` (#1275, @dark64) + ## [0.8.4] - 2023-01-31 ### Release diff --git a/Cargo.lock b/Cargo.lock index e21917722..e780f0273 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2906,7 +2906,7 @@ dependencies = [ [[package]] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" dependencies = [ "cfg-if 0.1.10", "csv", @@ -2956,7 +2956,7 @@ dependencies = [ [[package]] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" dependencies = [ "ark-bls12-377", "cfg-if 0.1.10", @@ -3004,7 +3004,7 @@ dependencies = [ [[package]] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" dependencies = [ "assert_cli", "blake2 0.8.1", @@ -3064,7 +3064,7 @@ dependencies = [ [[package]] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" dependencies = [ "cfg-if 0.1.10", "csv", @@ -3147,7 +3147,7 @@ dependencies = [ [[package]] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" dependencies = [ "ark-bls12-377", "num", @@ -3163,7 +3163,7 @@ dependencies = [ [[package]] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" dependencies = [ "console_error_panic_hook", "getrandom", diff --git a/changelogs/unreleased/1275-dark64 b/changelogs/unreleased/1275-dark64 deleted file mode 100644 index 506dce652..000000000 --- a/changelogs/unreleased/1275-dark64 +++ /dev/null @@ -1 +0,0 @@ -Show help when running `zokrates mpc` \ No newline at end of file diff --git a/changelogs/unreleased/1277-dark64 b/changelogs/unreleased/1277-dark64 deleted file mode 100644 index e94ff6a55..000000000 --- a/changelogs/unreleased/1277-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix a potential crash in `zokrates-js` due to inefficient serialization of a setup keypair diff --git a/changelogs/unreleased/1280-dark64 b/changelogs/unreleased/1280-dark64 deleted file mode 100644 index e6392c713..000000000 --- a/changelogs/unreleased/1280-dark64 +++ /dev/null @@ -1 +0,0 @@ -Fix `radix-path` help message on `mpc init` subcommand \ No newline at end of file diff --git a/changelogs/unreleased/1283-schaeff b/changelogs/unreleased/1283-schaeff deleted file mode 100644 index f4cf39095..000000000 --- a/changelogs/unreleased/1283-schaeff +++ /dev/null @@ -1 +0,0 @@ -Reduce memory usage and runtime by refactoring the reducer (ssa, propagation, unrolling and inlining) \ No newline at end of file diff --git a/zokrates_analysis/Cargo.toml b/zokrates_analysis/Cargo.toml index e347abcdf..f93dc7d8d 100644 --- a/zokrates_analysis/Cargo.toml +++ b/zokrates_analysis/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_analysis" -version = "0.1.0" +version = "0.1.1" edition = "2021" [features] diff --git a/zokrates_ast/Cargo.toml b/zokrates_ast/Cargo.toml index 6d9b4324b..60eb498c6 100644 --- a/zokrates_ast/Cargo.toml +++ b/zokrates_ast/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_ast" -version = "0.1.4" +version = "0.1.5" edition = "2021" [features] diff --git a/zokrates_cli/Cargo.toml b/zokrates_cli/Cargo.toml index 3fa5b5feb..ccf98794d 100644 --- a/zokrates_cli/Cargo.toml +++ b/zokrates_cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_cli" -version = "0.8.4" +version = "0.8.5" authors = ["Jacob Eberhardt ", "Dennis Kuhnert ", "Thibaut Schaeffer "] repository = "https://github.com/Zokrates/ZoKrates.git" edition = "2018" diff --git a/zokrates_core/Cargo.toml b/zokrates_core/Cargo.toml index aa1964c52..2b2530047 100644 --- a/zokrates_core/Cargo.toml +++ b/zokrates_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_core" -version = "0.7.3" +version = "0.7.4" edition = "2021" authors = ["Jacob Eberhardt ", "Dennis Kuhnert "] repository = "https://github.com/Zokrates/ZoKrates" diff --git a/zokrates_interpreter/Cargo.toml b/zokrates_interpreter/Cargo.toml index 41cdfa353..270436685 100644 --- a/zokrates_interpreter/Cargo.toml +++ b/zokrates_interpreter/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_interpreter" -version = "0.1.2" +version = "0.1.3" edition = "2021" [features] diff --git a/zokrates_js/Cargo.toml b/zokrates_js/Cargo.toml index 0c86329b9..02374289b 100644 --- a/zokrates_js/Cargo.toml +++ b/zokrates_js/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zokrates_js" -version = "1.1.5" +version = "1.1.6" authors = ["Darko Macesic"] edition = "2018" diff --git a/zokrates_js/package.json b/zokrates_js/package.json index 7617f20f3..02b6e919e 100644 --- a/zokrates_js/package.json +++ b/zokrates_js/package.json @@ -1,6 +1,6 @@ { "name": "zokrates-js", - "version": "1.1.5", + "version": "1.1.6", "module": "index.js", "main": "index-node.js", "description": "JavaScript bindings for ZoKrates",