Skip to content

Commit

Permalink
Merge pull request #28392 from ProvableHQ/ifelse
Browse files Browse the repository at this point in the history
Improve handling of guards while flattening.
  • Loading branch information
d0cd authored Oct 18, 2024
2 parents 76b8654 + 5e7591f commit 86783da
Show file tree
Hide file tree
Showing 55 changed files with 740 additions and 356 deletions.
28 changes: 22 additions & 6 deletions compiler/passes/src/flattening/flatten_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::Flattener;
use crate::{Flattener, ReturnGuard};

use leo_ast::{Function, ProgramReconstructor, ProgramScope, Statement, StatementReconstructor};
use leo_ast::{
Expression,
Function,
ProgramReconstructor,
ProgramScope,
ReturnStatement,
Statement,
StatementReconstructor,
};

impl ProgramReconstructor for Flattener<'_> {
/// Flattens a program scope.
Expand Down Expand Up @@ -47,11 +55,19 @@ impl ProgramReconstructor for Flattener<'_> {
// Flatten the function body.
let mut block = self.reconstruct_block(function.block).0;

// Get all of the guards and return expression.
let returns = self.clear_early_returns();

// Fold the return statements into the block.
self.fold_returns(&mut block, returns);
let returns = std::mem::take(&mut self.returns);
let expression_returns: Vec<(Option<Expression>, ReturnStatement)> = returns
.into_iter()
.map(|(guard, statement)| match guard {
ReturnGuard::None => (None, statement),
ReturnGuard::Unconstructed(plain) | ReturnGuard::Constructed { plain, .. } => {
(Some(Expression::Identifier(plain)), statement)
}
})
.collect();

self.fold_returns(&mut block, expression_returns);

Function {
annotations: function.annotations,
Expand Down
204 changes: 126 additions & 78 deletions compiler/passes/src/flattening/flatten_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::Flattener;
use crate::{Flattener, Guard, ReturnGuard};

use leo_ast::{
AssertStatement,
Expand All @@ -28,6 +28,7 @@ use leo_ast::{
DefinitionStatement,
Expression,
ExpressionReconstructor,
Identifier,
IterationStatement,
Node,
ReturnStatement,
Expand Down Expand Up @@ -93,77 +94,98 @@ impl StatementReconstructor for Flattener<'_> {
},
};

// Add the appropriate guards.
match self.construct_guard() {
// If the condition stack is empty, we can return the flattened assert statement.
None => (Statement::Assert(assert), statements),
// Otherwise, we need to join the guard with the expression in the flattened assert statement.
// Note given the guard and the expression, we construct the logical formula `guard => expression`,
// which is equivalent to `!guard || expression`.
Some(guard) => (
Statement::Assert(AssertStatement {
span: input.span,
id: input.id,
variant: AssertVariant::Assert(Expression::Binary(BinaryExpression {
op: BinaryOperation::Or,
span: Default::default(),
id: {
// Create a new node ID for the binary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the binary expression.
self.type_table.insert(id, Type::Boolean);
id
},
// Take the logical negation of the guard.
left: Box::new(Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(guard),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
})),
right: Box::new(match assert.variant {
// If the assert statement is an `assert`, use the expression as is.
AssertVariant::Assert(expression) => expression,
// If the assert statement is an `assert_eq`, construct a new equality expression.
AssertVariant::AssertEq(left, right) => Expression::Binary(BinaryExpression {
left: Box::new(left),
op: BinaryOperation::Eq,
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
}),
// If the assert statement is an `assert_ne`, construct a new inequality expression.
AssertVariant::AssertNeq(left, right) => Expression::Binary(BinaryExpression {
left: Box::new(left),
op: BinaryOperation::Neq,
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
}),
}),
})),
}),
statements,
),
let mut guards: Vec<Expression> = Vec::new();

if let Some((guard, guard_statements)) = self.construct_guard() {
statements.extend(guard_statements);

// The not_guard is true if we didn't follow the condition chain
// that led to this assertion.
let not_guard = Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(Expression::Identifier(guard)),
span: Default::default(),
id: {
// Create a new node ID for the unary expression.
let id = self.node_builder.next_id();
// Update the type table with the type of the unary expression.
self.type_table.insert(id, Type::Boolean);
id
},
});
let (identifier, statement) = self.unique_simple_assign_statement(not_guard);
statements.push(statement);
guards.push(Expression::Identifier(identifier));
}

// We also need to guard against early returns.
if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
guards.push(Expression::Identifier(guard));
statements.extend(guard_statements);
}

if guards.is_empty() {
return (Statement::Assert(assert), statements);
}

let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));

// We need to `or` the asserted expression with the guards,
// so extract an appropriate expression.
let mut expression = match assert.variant {
// If the assert statement is an `assert`, use the expression as is.
AssertVariant::Assert(expression) => expression,

// For `assert_eq` or `assert_neq`, construct a new expression.
AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
let binary = Expression::Binary(BinaryExpression {
left: Box::new(left),
op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
right: Box::new(right),
span: Default::default(),
id: {
// Create a new node ID.
let id = self.node_builder.next_id();
// Update the type table.
self.type_table.insert(id, Type::Boolean);
id
},
});
let (identifier, statement) = self.unique_simple_assign_statement(binary);
statements.push(statement);
Expression::Identifier(identifier)
}
};

// The assertion will be that the original assert statement is true or one of the guards is true
// (ie, we either didn't follow the condition chain that led to this assert, or else we took an
// early return).
for guard in guards.into_iter() {
let binary = Expression::Binary(BinaryExpression {
op: BinaryOperation::Or,
span: Default::default(),
id: {
// Create a new node ID.
let id = self.node_builder.next_id();
// Update the type table.
self.type_table.insert(id, Type::Boolean);
id
},
left: Box::new(expression),
right: Box::new(guard),
});
let (identifier, statement) = self.unique_simple_assign_statement(binary);
statements.push(statement);
expression = Expression::Identifier(identifier);
}

let assert_statement = Statement::Assert(AssertStatement {
span: input.span,
id: input.id,
variant: AssertVariant::Assert(expression),
});

(assert_statement, statements)
}

/// Flattens an assign statement, if necessary.
Expand Down Expand Up @@ -250,8 +272,21 @@ impl StatementReconstructor for Flattener<'_> {
);
}

// Assign the condition to a variable, as it may be used multiple times.
let place = Identifier {
name: self.assigner.unique_symbol("condition", "$"),
span: Default::default(),
id: {
let id = self.node_builder.next_id();
self.type_table.insert(id, Type::Boolean);
id
},
};

statements.push(self.simple_assign_statement(place, conditional.condition.clone()));

// Add condition to the condition stack.
self.condition_stack.push(conditional.condition.clone());
self.condition_stack.push(Guard::Unconstructed(place));

// Reconstruct the then-block and accumulate it constituent statements.
statements.extend(self.reconstruct_block(conditional.then).0.statements);
Expand All @@ -261,13 +296,24 @@ impl StatementReconstructor for Flattener<'_> {

// Consume the otherwise-block and flatten its constituent statements into the current block.
if let Some(statement) = conditional.otherwise {
// Add the negated condition to the condition stack.
self.condition_stack.push(Expression::Unary(UnaryExpression {
// Apply Not to the condition, assign it, and put it on the condition stack.
let not_condition = Expression::Unary(UnaryExpression {
op: UnaryOperation::Not,
receiver: Box::new(conditional.condition.clone()),
span: conditional.condition.span(),
id: conditional.condition.id(),
}));
});
let not_place = Identifier {
name: self.assigner.unique_symbol("condition", "$"),
span: Default::default(),
id: {
let id = self.node_builder.next_id();
self.type_table.insert(id, Type::Boolean);
id
},
};
statements.push(self.simple_assign_statement(not_place, not_condition));
self.condition_stack.push(Guard::Unconstructed(not_place));

// Reconstruct the otherwise-block and accumulate it constituent statements.
match *statement {
Expand Down Expand Up @@ -302,15 +348,17 @@ impl StatementReconstructor for Flattener<'_> {
return (Statement::Return(input), Default::default());
}
// Construct the associated guard.
let guard = self.construct_guard();
let (guard_identifier, statements) = self.construct_guard().unzip();

let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);

match input.expression {
Expression::Unit(_) | Expression::Identifier(_) | Expression::Access(_) => {
self.returns.push((guard, input))
self.returns.push((return_guard, input))
}
_ => unreachable!("SSA guarantees that the expression is always an identifier or unit expression."),
};

(Statement::dummy(Default::default(), self.node_builder.next_id()), Default::default())
(Statement::dummy(Default::default(), self.node_builder.next_id()), statements.unwrap_or_default())
}
}
Loading

0 comments on commit 86783da

Please sign in to comment.