diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 306fcdf3b..2e9e67708 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -28,37 +28,44 @@ impl Display for Analyzed { for statement in &self.source_order { match statement { StatementIdentifier::Definition(name) => { - let (symbol, definition) = &self.definitions[name]; - let (name, is_local) = update_namespace(name, symbol.degree, f)?; - match symbol.kind { - SymbolKind::Poly(poly_type) => { - let kind = match &poly_type { - PolynomialType::Committed => "witness ", - PolynomialType::Constant => "fixed ", - PolynomialType::Intermediate => "", - }; - write!(f, " col {kind}{name}")?; - if let Some(value) = definition { - writeln!(f, "{value};")? - } else { - writeln!(f, ";")? + if let Some((symbol, definition)) = self.definitions.get(name) { + let (name, is_local) = update_namespace(name, symbol.degree, f)?; + match symbol.kind { + SymbolKind::Poly(poly_type) => { + let kind = match &poly_type { + PolynomialType::Committed => "witness ", + PolynomialType::Constant => "fixed ", + PolynomialType::Intermediate => panic!(), + }; + write!(f, " col {kind}{name}")?; + if let Some(value) = definition { + writeln!(f, "{value};")? + } else { + writeln!(f, ";")? + } } - } - SymbolKind::Constant() => { - let indentation = if is_local { " " } else { "" }; - writeln!( - f, - "{indentation}constant {name}{};", - definition.as_ref().unwrap() - )?; - } - SymbolKind::Other() => { - write!(f, " let {name}")?; - if let Some(value) = definition { - write!(f, "{value}")? + SymbolKind::Constant() => { + let indentation = if is_local { " " } else { "" }; + writeln!( + f, + "{indentation}constant {name}{};", + definition.as_ref().unwrap() + )?; + } + SymbolKind::Other() => { + write!(f, " let {name}")?; + if let Some(value) = definition { + write!(f, "{value}")? + } + writeln!(f, ";")? } - writeln!(f, ";")? } + } else if let Some((symbol, definition)) = self.intermediate_columns.get(name) { + let (name, _) = update_namespace(name, symbol.degree, f)?; + assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate)); + writeln!(f, " col {name} = {definition};")?; + } else { + panic!() } } StatementIdentifier::PublicDeclaration(name) => { diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index d79edb914..d977b54f2 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -17,6 +17,7 @@ use crate::parsed::{self, SelectedExpressions}; #[derive(Debug)] pub enum StatementIdentifier { + /// Either an intermediate column or a definition. Definition(String), PublicDeclaration(String), /// Index into the vector of identities. @@ -27,6 +28,7 @@ pub enum StatementIdentifier { pub struct Analyzed { pub definitions: HashMap>)>, pub public_declarations: HashMap, + pub intermediate_columns: HashMap)>, pub identities: Vec>>, /// The order in which definitions and identities /// appear in the source. @@ -40,7 +42,7 @@ impl Analyzed { } /// @returns the number of intermediate polynomials (with multiplicities for arrays) pub fn intermediate_count(&self) -> usize { - self.declaration_type_count(PolynomialType::Intermediate) + self.intermediate_columns.len() } /// @returns the number of constant polynomials (with multiplicities for arrays) pub fn constant_count(&self) -> usize { @@ -59,26 +61,39 @@ impl Analyzed { self.definitions_in_source_order(PolynomialType::Committed) } - pub fn intermediate_polys_in_source_order( - &self, - ) -> Vec<&(Symbol, Option>)> { - self.definitions_in_source_order(PolynomialType::Intermediate) + pub fn intermediate_polys_in_source_order(&self) -> Vec<&(Symbol, Expression)> { + self.source_order + .iter() + .filter_map(move |statement| { + if let StatementIdentifier::Definition(name) = statement { + if let Some(definition) = self.intermediate_columns.get(name) { + return Some(definition); + } + } + None + }) + .collect() } pub fn definitions_in_source_order( &self, poly_type: PolynomialType, ) -> Vec<&(Symbol, Option>)> { + assert!( + poly_type != PolynomialType::Intermediate, + "Use intermediate_polys_in_source_order to get intermediate polys." + ); self.source_order .iter() .filter_map(move |statement| { if let StatementIdentifier::Definition(name) = statement { - let definition = &self.definitions[name]; - match definition.0.kind { - SymbolKind::Poly(ptype) if ptype == poly_type => { - return Some(definition); + if let Some(definition) = self.definitions.get(name) { + match definition.0.kind { + SymbolKind::Poly(ptype) if ptype == poly_type => { + return Some(definition); + } + _ => {} } - _ => {} } } None @@ -102,12 +117,11 @@ impl Analyzed { /// so that they are contiguous again. /// There must not be any reference to the removed polynomials left. pub fn remove_polynomials(&mut self, to_remove: &BTreeSet) { - let replacements: BTreeMap = [ + let mut replacements: BTreeMap = [ // We have to do it separately because we need to re-start the counter // for each kind. self.committed_polys_in_source_order(), self.constant_polys_in_source_order(), - self.intermediate_polys_in_source_order(), ] .map(|polys| { polys @@ -137,6 +151,13 @@ impl Analyzed { .flatten() .collect(); + // We assume for now that intermediate columns are not removed. + for (poly, _) in self.intermediate_columns.values() { + let poly_id: PolyID = poly.into(); + assert!(!to_remove.contains(&poly_id)); + replacements.insert(poly_id, poly_id); + } + let mut names_to_remove: HashSet = Default::default(); self.definitions.retain(|name, (poly, _def)| { if matches!(poly.kind, SymbolKind::Poly(_)) @@ -235,6 +256,9 @@ impl Analyzed { self.identities .iter_mut() .for_each(|i| i.post_visit_expressions_mut(f)); + self.intermediate_columns + .values_mut() + .for_each(|(_sym, value)| value.post_visit_expressions_mut(f)); } pub fn post_visit_expressions_in_definitions_mut(&mut self, f: &mut F) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index d9f10651e..e51b6d94b 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -22,3 +22,4 @@ pilopt = { path = "../pilopt" } mktemp = "0.5.0" test-log = "0.2.12" env_logger = "0.10.0" +pretty_assertions = "1.4.0" \ No newline at end of file diff --git a/backend/src/pilstark/json_exporter/expression_counter.rs b/backend/src/pilstark/json_exporter/expression_counter.rs index af161c53d..86784af07 100644 --- a/backend/src/pilstark/json_exporter/expression_counter.rs +++ b/backend/src/pilstark/json_exporter/expression_counter.rs @@ -15,11 +15,16 @@ pub fn compute_intermediate_expression_ids(analyzed: &Analyzed) -> HashMap for item in &analyzed.source_order { expression_counter += match item { StatementIdentifier::Definition(name) => { - let poly = &analyzed.definitions[name].0; - if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) { + if let Some((poly, _)) = analyzed.definitions.get(name) { + assert!(poly.kind != SymbolKind::Poly(PolynomialType::Intermediate)); + poly.expression_count() + } else if let Some((poly, _)) = analyzed.intermediate_columns.get(name) { + assert!(poly.kind == SymbolKind::Poly(PolynomialType::Intermediate)); ids.insert(poly.id, expression_counter as u64); + poly.expression_count() + } else { + unreachable!() } - poly.expression_count() } StatementIdentifier::PublicDeclaration(name) => { analyzed.public_declarations[name].expression_count() diff --git a/backend/src/pilstark/json_exporter/mod.rs b/backend/src/pilstark/json_exporter/mod.rs index 76fd48573..a24b7e27c 100644 --- a/backend/src/pilstark/json_exporter/mod.rs +++ b/backend/src/pilstark/json_exporter/mod.rs @@ -3,8 +3,8 @@ use std::cmp; use std::collections::HashMap; use ast::analyzed::{ - self, Analyzed, BinaryOperator, Expression, FunctionValueDefinition, IdentityKind, PolyID, - PolynomialReference, PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator, + self, Analyzed, BinaryOperator, Expression, IdentityKind, PolyID, PolynomialReference, + PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator, }; use starky::types::{ ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity, @@ -46,18 +46,13 @@ pub fn export(analyzed: &Analyzed) -> PIL { for item in &analyzed.source_order { match item { StatementIdentifier::Definition(name) => { - if let (poly, Some(value)) = &analyzed.definitions[name] { - if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) { - if let FunctionValueDefinition::Expression(value) = value { - let expression_id = exporter.extract_expression(value, 1); - assert_eq!( - expression_id, - exporter.intermediate_poly_expression_ids[&poly.id] as usize - ); - } else { - panic!("Expected single value"); - } - } + if let Some((poly, value)) = analyzed.intermediate_columns.get(name) { + assert_eq!(poly.kind, SymbolKind::Poly(PolynomialType::Intermediate)); + let expression_id = exporter.extract_expression(value, 1); + assert_eq!( + expression_id, + exporter.intermediate_poly_expression_ids[&poly.id] as usize + ); } } StatementIdentifier::PublicDeclaration(name) => { @@ -187,7 +182,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> { .filter_map(|(name, (symbol, _value))| { let id = match symbol.kind { SymbolKind::Poly(PolynomialType::Intermediate) => { - Some(self.intermediate_poly_expression_ids[&symbol.id]) + panic!("Should be in intermediates") } SymbolKind::Poly(_) => Some(symbol.id), SymbolKind::Other() | SymbolKind::Constant() => None, @@ -204,6 +199,26 @@ impl<'a, T: FieldElement> Exporter<'a, T> { }; Some((name.clone(), out)) }) + .chain( + self.analyzed + .intermediate_columns + .iter() + .map(|(name, (symbol, _))| { + assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate)); + let id = self.intermediate_poly_expression_ids[&symbol.id]; + + let out = Reference { + polType: None, + type_: symbol_kind_to_json_string(symbol.kind).to_string(), + id: id as usize, + polDeg: symbol.degree as usize, + isArray: symbol.is_array(), + elementType: None, + len: symbol.length.map(|l| l as usize), + }; + (name.clone(), out) + }), + ) .collect::>() } @@ -368,6 +383,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> { #[cfg(test)] mod test { use pil_analyzer::analyze; + use pretty_assertions::assert_eq; use serde_json::Value as JsonValue; use std::{fs, process::Command}; use test_log::test; diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 6bf3cd6c6..f6c0e7f69 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -42,6 +42,7 @@ fn generate_values( .map(|i| { Evaluator { definitions: &analyzed.definitions, + intermediate_columns: &analyzed.intermediate_columns, variables: &[i.into()], function_cache: other_constants, } @@ -52,6 +53,7 @@ fn generate_values( FunctionValueDefinition::Array(values) => { let evaluator = Evaluator { definitions: &analyzed.definitions, + intermediate_columns: &analyzed.intermediate_columns, variables: &[], function_cache: other_constants, }; diff --git a/pil_analyzer/src/condenser.rs b/pil_analyzer/src/condenser.rs index 47f97db9f..b79bbea50 100644 --- a/pil_analyzer/src/condenser.rs +++ b/pil_analyzer/src/condenser.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use ast::{ analyzed::{ Analyzed, Expression, FunctionValueDefinition, Identity, PolynomialReference, - PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, + PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, }, evaluate_binary_operation, evaluate_unary_operation, parsed::{visitor::ExpressionVisitable, SelectedExpressions}, @@ -22,7 +22,7 @@ pub fn condense( source_order: Vec, ) -> Analyzed { let condenser = Condenser { - constants: compute_constants(&definitions), + constants: compute_constants(&definitions, &Default::default()), symbols: definitions .iter() .map(|(name, (symbol, _))| (name.clone(), symbol.clone())) @@ -34,6 +34,25 @@ pub fn condense( .map(|identity| condenser.condense_identity(identity)) .collect(); + // Extract intermediate columns + let intermediate_columns: HashMap<_, _> = definitions + .iter() + .filter_map(|(name, (symbol, definition))| { + if matches!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate)) { + let Some(FunctionValueDefinition::Expression(e)) = definition else { + panic!("Expected expression") + }; + Some(( + name.clone(), + (symbol.clone(), condenser.condense_expression(e.clone())), + )) + } else { + None + } + }) + .collect(); + definitions.retain(|name, _| !intermediate_columns.contains_key(name)); + definitions.values_mut().for_each(|(_, definition)| { if let Some(def) = definition { def.post_visit_expressions_mut(&mut |e| { @@ -50,6 +69,7 @@ pub fn condense( Analyzed { definitions, public_declarations, + intermediate_columns, identities, source_order, } diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs index 8a3db5e07..57572afdf 100644 --- a/pil_analyzer/src/evaluator.rs +++ b/pil_analyzer/src/evaluator.rs @@ -10,10 +10,12 @@ use number::FieldElement; /// Evaluates an expression to a single value. pub fn evaluate_expression( definitions: &HashMap>)>, + intermediate_columns: &HashMap)>, expression: &Expression, ) -> Result { Evaluator { definitions, + intermediate_columns, function_cache: &Default::default(), variables: &[], } @@ -23,6 +25,7 @@ pub fn evaluate_expression( /// Returns a HashMap of all symbols that have a constant single value. pub fn compute_constants( definitions: &HashMap>)>, + intermediate_columns: &HashMap)>, ) -> HashMap { definitions .iter() @@ -33,7 +36,7 @@ pub fn compute_constants( }; ( name.to_owned(), - evaluate_expression(definitions, value).unwrap(), + evaluate_expression(definitions, intermediate_columns, value).unwrap(), ) }) }) @@ -42,6 +45,7 @@ pub fn compute_constants( pub struct Evaluator<'a, T> { pub definitions: &'a HashMap>)>, + pub intermediate_columns: &'a HashMap)>, /// Contains full value tables of functions (columns) we already evaluated. pub function_cache: &'a HashMap<&'a str, Vec>, pub variables: &'a [T], @@ -53,10 +57,19 @@ impl<'a, T: FieldElement> Evaluator<'a, T> { Expression::Reference(Reference::LocalVar(i, _name)) => Ok(self.variables[*i as usize]), Expression::Reference(Reference::Poly(poly)) => { if !poly.next && poly.index.is_none() { - let (_, value) = &self.definitions[&poly.name.to_string()]; - match value { - Some(FunctionValueDefinition::Expression(value)) => self.evaluate(value), - _ => Err("Cannot evaluate function-typed values".to_string()), + if let Some((_, value)) = self.definitions.get(&poly.name.to_string()) { + match value { + Some(FunctionValueDefinition::Expression(value)) => { + self.evaluate(value) + } + _ => Err("Cannot evaluate function-typed values".to_string()), + } + } else if let Some((_, value)) = + self.intermediate_columns.get(&poly.name.to_string()) + { + self.evaluate(value) + } else { + unreachable!() } } else { Err("Cannot evaluate arrays or next references.".to_string()) diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 59ad164e2..61de42107 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -459,6 +459,7 @@ impl PILAnalyzer { fn evaluate_expression(&self, expr: ::ast::parsed::Expression) -> Result { Evaluator { definitions: &self.definitions, + intermediate_columns: &Default::default(), function_cache: &Default::default(), variables: &[], } @@ -655,17 +656,9 @@ pub fn inline_intermediate_polynomials( substitute_intermediate( analyzed.identities.clone(), &analyzed - .definitions_in_source_order(PolynomialType::Intermediate) + .intermediate_polys_in_source_order() .iter() - .map(|(symbol, def)| { - ( - symbol.id, - match def.as_ref().unwrap() { - FunctionValueDefinition::Expression(e) => e.clone(), - _ => unreachable!(), - }, - ) - }) + .map(|(symbol, def)| (symbol.id, def.clone())) .collect(), ) } diff --git a/pilopt/src/lib.rs b/pilopt/src/lib.rs index 256d1e4b3..5f950bd22 100644 --- a/pilopt/src/lib.rs +++ b/pilopt/src/lib.rs @@ -46,7 +46,7 @@ pub fn optimize_constants(mut pil_file: Analyzed) -> Analyze /// Inlines references to symbols with a single constant value. fn inline_constant_values(pil_file: &mut Analyzed) { - let constants = compute_constants(&pil_file.definitions); + let constants = compute_constants(&pil_file.definitions, &pil_file.intermediate_columns); let visitor = &mut |e: &mut Expression<_>| { if let Expression::Reference(Reference::Poly(poly)) = e { if !poly.next && poly.index.is_none() {