Skip to content

Commit

Permalink
Merge pull request iden3#64 from iden3/plonk-custom-gates
Browse files Browse the repository at this point in the history
Plonk custom gates
  • Loading branch information
clararod9 committed Jun 8, 2022
2 parents e8e606a + aedd8f5 commit 6b5cd01
Show file tree
Hide file tree
Showing 15 changed files with 574 additions and 99 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ parser/target/
program_structure/target/
type_analysis/target/
.idea/
.vscode
.DS_Store
Cargo.lock
8 changes: 4 additions & 4 deletions circom/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ fn start() -> Result<(), ()> {
c_flag: user_input.c_flag(),
wasm_flag: user_input.wasm_flag(),
wat_flag: user_input.wat_flag(),
js_folder: user_input.js_folder().to_string(),
wasm_name: user_input.wasm_name().to_string(),
c_folder: user_input.c_folder().to_string(),
c_run_name: user_input.c_run_name().to_string(),
js_folder: user_input.js_folder().to_string(),
wasm_name: user_input.wasm_name().to_string(),
c_folder: user_input.c_folder().to_string(),
c_run_name: user_input.c_run_name().to_string(),
c_file: user_input.c_file().to_string(),
dat_file: user_input.dat_file().to_string(),
wat_file: user_input.wat_file().to_string(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct ComponentRepresentation {
unassigned_inputs: HashMap<String, SliceCapacity>,
inputs: HashMap<String, SignalSlice>,
outputs: HashMap<String, SignalSlice>,
pub is_custom_gate: bool,
}

impl Default for ComponentRepresentation {
Expand All @@ -17,6 +18,7 @@ impl Default for ComponentRepresentation {
unassigned_inputs: HashMap::new(),
inputs: HashMap::new(),
outputs: HashMap::new(),
is_custom_gate: false,
}
}
}
Expand All @@ -27,6 +29,7 @@ impl Clone for ComponentRepresentation {
unassigned_inputs: self.unassigned_inputs.clone(),
inputs: self.inputs.clone(),
outputs: self.outputs.clone(),
is_custom_gate: self.is_custom_gate,
}
}
}
Expand All @@ -49,9 +52,8 @@ impl ComponentRepresentation {
for (symbol, route) in node.inputs() {
let signal_slice = SignalSlice::new_with_route(route, &false);
let signal_slice_size = SignalSlice::get_number_of_cells(&signal_slice);
if signal_slice_size > 0{
unassigned_inputs
.insert(symbol.clone(), signal_slice_size);
if signal_slice_size > 0 {
unassigned_inputs.insert(symbol.clone(), signal_slice_size);
}
inputs.insert(symbol.clone(), signal_slice);
}
Expand All @@ -65,9 +67,11 @@ impl ComponentRepresentation {
unassigned_inputs,
inputs,
outputs,
is_custom_gate: node.is_custom_gate,
};
Result::Ok(())
}

pub fn signal_has_value(
component: &ComponentRepresentation,
signal_name: &str,
Expand All @@ -88,6 +92,7 @@ impl ComponentRepresentation {
let enabled = *SignalSlice::get_reference_to_single_value(slice, access)?;
Result::Ok(enabled)
}

pub fn get_signal(&self, signal_name: &str) -> Result<&SignalSlice, MemoryError> {
if self.node_pointer.is_none() {
return Result::Err(MemoryError::InvalidAccess);
Expand Down
117 changes: 97 additions & 20 deletions constraint_generation/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::{
ast::*, ArithmeticError, FileID, ProgramArchive, Report, ReportCode, ReportCollection
};
use circom_algebra::num_bigint::BigInt;
use std::collections::BTreeMap;

type AExpr = ArithmeticExpressionGen<String>;

#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
Expand Down Expand Up @@ -57,6 +57,7 @@ impl RuntimeInformation {
struct FoldedValue {
pub arithmetic_slice: Option<AExpressionSlice>,
pub node_pointer: Option<NodePointer>,
pub custom_gate_name: Option<String>,
}
impl FoldedValue {
pub fn valid_arithmetic_slice(f_value: &FoldedValue) -> bool {
Expand All @@ -69,7 +70,11 @@ impl FoldedValue {

impl Default for FoldedValue {
fn default() -> Self {
FoldedValue { arithmetic_slice: Option::None, node_pointer: Option::None }
FoldedValue {
arithmetic_slice: Option::None,
node_pointer: Option::None,
custom_gate_name: Option::None
}
}
}

Expand Down Expand Up @@ -191,10 +196,22 @@ fn execute_statement(
Option::None
}
Substitution { meta, var, access, op, rhe, .. } => {
let access_information = treat_accessing(meta, access, program_archive, runtime, flag_verbose)?;
let access_information = treat_accessing(
meta,
access,
program_archive,
runtime,
flag_verbose
)?;
let r_folded = execute_expression(rhe, program_archive, runtime, flag_verbose)?;
let possible_constraint =
perform_assign(meta, var, &access_information, r_folded, actual_node, runtime)?;
let possible_constraint = perform_assign(
meta,
var,
&access_information,
r_folded,
actual_node,
runtime
)?;
if let (Option::Some(node), AssignOp::AssignConstraintSignal) = (actual_node, op) {
debug_assert!(possible_constraint.is_some());
let constrained = possible_constraint.unwrap();
Expand All @@ -211,7 +228,40 @@ fn execute_statement(
let symbol = AExpr::Signal { symbol: constrained.left };
let expr = AExpr::sub(&symbol, &constrained.right, &p);
let ctr = AExpr::transform_expression_to_constraint_form(expr, &p).unwrap();
node.add_constraint(ctr);
if constrained.custom_gate_name.is_none() {
node.add_constraint(ctr);
} else {
// From a previous semantic analysis we know that in this case we must
// have that constrained.right is an AExpr::Signal, so we can safely unwrap
// the name of the signal in the right hand side of the expression.
debug_assert!(matches!(symbol, AExpr::Signal {..}));
debug_assert!(matches!(constrained.right, AExpr::Signal {..}));
if let AExpr::Signal { symbol: left } = symbol {
if let AExpr::Signal { symbol: right } = constrained.right {
let custom_gate_name = constrained.custom_gate_name.unwrap();
fn reorder(
left: String,
right: String,
custom_gate_name: &String
) -> (String, String) {
if left.starts_with(custom_gate_name) {
(left, right)
} else {
debug_assert!(right.starts_with(custom_gate_name));
(right, left)
}
}

// Assignment of the form left <== right
let (inner, outer) = reorder(left, right, &custom_gate_name);
node.treat_custom_gate_constraint(custom_gate_name, inner, outer);
} else {
unreachable!();
}
} else {
unreachable!();
}
}
}
}
Option::None
Expand Down Expand Up @@ -488,6 +538,7 @@ fn execute_signal_declaration(
) {
use SignalType::*;
if let Option::Some(node) = actual_node {
node.add_ordered_signal(signal_name, dimensions);
match signal_type {
Input => {
environment_shortcut_add_input(environment, signal_name, dimensions);
Expand All @@ -514,6 +565,7 @@ fn execute_signal_declaration(
struct Constrained {
left: String,
right: AExpr,
custom_gate_name: Option<String>,
}
fn perform_assign(
meta: &Meta,
Expand All @@ -526,9 +578,8 @@ fn perform_assign(
use super::execution_data::type_definitions::SubComponentData;
let environment = &mut runtime.environment;
let full_symbol = create_symbol(symbol, &accessing_information);

let possible_arithmetic_expression = if ExecutionEnvironment::has_variable(environment, symbol)
{
let possible_custom_gate_name = r_folded.custom_gate_name.clone();
let possible_arithmetic_expression = if ExecutionEnvironment::has_variable(environment, symbol) { // review!
debug_assert!(accessing_information.signal_access.is_none());
debug_assert!(accessing_information.after_signal.is_empty());
let environment_result = ExecutionEnvironment::get_mut_variable_mut(environment, symbol);
Expand Down Expand Up @@ -605,7 +656,7 @@ fn perform_assign(
&mut runtime.runtime_errors,
&runtime.call_trace,
)?;
Option::Some(safe_unwrap_to_single_arithmetic_expression(r_folded, line!()))
Option::Some((safe_unwrap_to_single_arithmetic_expression(r_folded, line!()), None))
} else if ExecutionEnvironment::has_component(environment, symbol) {
let environment_response = ExecutionEnvironment::get_mut_component_res(environment, symbol);
let component_slice = treat_result_with_environment_error(
Expand Down Expand Up @@ -665,14 +716,28 @@ fn perform_assign(
&mut runtime.runtime_errors,
&runtime.call_trace,
)?;
Option::Some(AExpressionSlice::unwrap_to_single(arithmetic_slice))
let custom_gate = if component.is_custom_gate { Some(symbol.to_string()) } else { None };
Option::Some((AExpressionSlice::unwrap_to_single(arithmetic_slice), custom_gate))
}
} else {
unreachable!();
};
if let Option::Some(arithmetic_expression) = possible_arithmetic_expression {
let ret = Constrained { left: full_symbol, right: arithmetic_expression };
Result::Ok(Some(ret))
if let Option::Some((arithmetic_expression, custom_gate_name)) = possible_arithmetic_expression {
if custom_gate_name.is_none() {
let ret = Constrained {
left: full_symbol,
right: arithmetic_expression,
custom_gate_name: possible_custom_gate_name,
};
Result::Ok(Some(ret))
} else {
let ret = Constrained {
left: full_symbol,
right: arithmetic_expression,
custom_gate_name
};
Result::Ok(Some(ret))
}
} else {
Result::Ok(None)
}
Expand Down Expand Up @@ -909,6 +974,11 @@ fn execute_component(
&mut runtime.runtime_errors,
&runtime.call_trace,
)?;
let custom_gate_name = if checked_component.is_custom_gate {
Some(symbol.to_string())
} else {
None
};
if let Option::Some(signal_name) = &access_information.signal_access {
let access_after_signal = &access_information.after_signal;
let signal = treat_result_with_memory_error(
Expand All @@ -925,8 +995,13 @@ fn execute_component(
&runtime.call_trace,
)?;
let symbol = create_symbol(symbol, &access_information);
let result = signal_to_arith(symbol, slice)
.map(|s| FoldedValue { arithmetic_slice: Option::Some(s), ..FoldedValue::default() });
let result = signal_to_arith(symbol, slice).map(|s|
FoldedValue {
arithmetic_slice: Option::Some(s),
custom_gate_name,
..FoldedValue::default()
}
);
treat_result_with_memory_error(
result,
meta,
Expand All @@ -936,6 +1011,7 @@ fn execute_component(
} else {
Result::Ok(FoldedValue {
node_pointer: checked_component.node_pointer,
custom_gate_name,
..FoldedValue::default()
})
}
Expand Down Expand Up @@ -988,14 +1064,14 @@ fn execute_template_call(
debug_assert!(runtime.block_type == BlockType::Known);
let is_main = std::mem::replace(&mut runtime.public_inputs, vec![]);
let is_parallel = program_archive.get_template_data(id).is_parallel();
let is_custom_gate = program_archive.get_template_data(id).is_custom_gate();
let args_names = program_archive.get_template_data(id).get_name_of_params();
let template_body = program_archive.get_template_data(id).get_body_as_vec();
let mut args_to_values = BTreeMap::new();
debug_assert_eq!(args_names.len(), parameter_values.len());
let mut args_to_values = vec![];
let mut instantiation_name = format!("{}(", id);
for (name, value) in args_names.iter().zip(parameter_values) {
instantiation_name.push_str(&format!("{},", value.to_string()));
args_to_values.insert(name.clone(), value.clone());
args_to_values.push((name.clone(), value.clone()));
}
if !parameter_values.is_empty() {
instantiation_name.pop();
Expand All @@ -1014,7 +1090,8 @@ fn execute_template_call(
instantiation_name,
args_to_values,
code,
is_parallel
is_parallel,
is_custom_gate,
));
let ret = execute_sequence_of_statements(
template_body,
Expand Down
Loading

0 comments on commit 6b5cd01

Please sign in to comment.