Skip to content

Commit

Permalink
Merge pull request powdr-labs#703 from powdr-labs/single_out_intermed…
Browse files Browse the repository at this point in the history
…iates

Treat intermediate columns specially.
  • Loading branch information
Leo authored Oct 16, 2023
2 parents b14ba2e + 164880d commit 7721ec8
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 76 deletions.
63 changes: 35 additions & 28 deletions ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,37 +28,44 @@ impl<T: Display> Display for Analyzed<T> {
for statement in &self.source_order {
match statement {
StatementIdentifier::Definition(name) => {
let (symbol, definition) = &self.definitions[name];
let (name, is_local) = update_namespace(name, symbol.degree, f)?;
match symbol.kind {
SymbolKind::Poly(poly_type) => {
let kind = match &poly_type {
PolynomialType::Committed => "witness ",
PolynomialType::Constant => "fixed ",
PolynomialType::Intermediate => "",
};
write!(f, " col {kind}{name}")?;
if let Some(value) = definition {
writeln!(f, "{value};")?
} else {
writeln!(f, ";")?
if let Some((symbol, definition)) = self.definitions.get(name) {
let (name, is_local) = update_namespace(name, symbol.degree, f)?;
match symbol.kind {
SymbolKind::Poly(poly_type) => {
let kind = match &poly_type {
PolynomialType::Committed => "witness ",
PolynomialType::Constant => "fixed ",
PolynomialType::Intermediate => panic!(),
};
write!(f, " col {kind}{name}")?;
if let Some(value) = definition {
writeln!(f, "{value};")?
} else {
writeln!(f, ";")?
}
}
}
SymbolKind::Constant() => {
let indentation = if is_local { " " } else { "" };
writeln!(
f,
"{indentation}constant {name}{};",
definition.as_ref().unwrap()
)?;
}
SymbolKind::Other() => {
write!(f, " let {name}")?;
if let Some(value) = definition {
write!(f, "{value}")?
SymbolKind::Constant() => {
let indentation = if is_local { " " } else { "" };
writeln!(
f,
"{indentation}constant {name}{};",
definition.as_ref().unwrap()
)?;
}
SymbolKind::Other() => {
write!(f, " let {name}")?;
if let Some(value) = definition {
write!(f, "{value}")?
}
writeln!(f, ";")?
}
writeln!(f, ";")?
}
} else if let Some((symbol, definition)) = self.intermediate_columns.get(name) {
let (name, _) = update_namespace(name, symbol.degree, f)?;
assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate));
writeln!(f, " col {name} = {definition};")?;
} else {
panic!()
}
}
StatementIdentifier::PublicDeclaration(name) => {
Expand Down
48 changes: 36 additions & 12 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::parsed::{self, SelectedExpressions};

#[derive(Debug)]
pub enum StatementIdentifier {
/// Either an intermediate column or a definition.
Definition(String),
PublicDeclaration(String),
/// Index into the vector of identities.
Expand All @@ -27,6 +28,7 @@ pub enum StatementIdentifier {
pub struct Analyzed<T> {
pub definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
pub public_declarations: HashMap<String, PublicDeclaration>,
pub intermediate_columns: HashMap<String, (Symbol, Expression<T>)>,
pub identities: Vec<Identity<Expression<T>>>,
/// The order in which definitions and identities
/// appear in the source.
Expand All @@ -40,7 +42,7 @@ impl<T> Analyzed<T> {
}
/// @returns the number of intermediate polynomials (with multiplicities for arrays)
pub fn intermediate_count(&self) -> usize {
self.declaration_type_count(PolynomialType::Intermediate)
self.intermediate_columns.len()
}
/// @returns the number of constant polynomials (with multiplicities for arrays)
pub fn constant_count(&self) -> usize {
Expand All @@ -59,26 +61,39 @@ impl<T> Analyzed<T> {
self.definitions_in_source_order(PolynomialType::Committed)
}

pub fn intermediate_polys_in_source_order(
&self,
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
self.definitions_in_source_order(PolynomialType::Intermediate)
pub fn intermediate_polys_in_source_order(&self) -> Vec<&(Symbol, Expression<T>)> {
self.source_order
.iter()
.filter_map(move |statement| {
if let StatementIdentifier::Definition(name) = statement {
if let Some(definition) = self.intermediate_columns.get(name) {
return Some(definition);
}
}
None
})
.collect()
}

pub fn definitions_in_source_order(
&self,
poly_type: PolynomialType,
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
assert!(
poly_type != PolynomialType::Intermediate,
"Use intermediate_polys_in_source_order to get intermediate polys."
);
self.source_order
.iter()
.filter_map(move |statement| {
if let StatementIdentifier::Definition(name) = statement {
let definition = &self.definitions[name];
match definition.0.kind {
SymbolKind::Poly(ptype) if ptype == poly_type => {
return Some(definition);
if let Some(definition) = self.definitions.get(name) {
match definition.0.kind {
SymbolKind::Poly(ptype) if ptype == poly_type => {
return Some(definition);
}
_ => {}
}
_ => {}
}
}
None
Expand All @@ -102,12 +117,11 @@ impl<T> Analyzed<T> {
/// so that they are contiguous again.
/// There must not be any reference to the removed polynomials left.
pub fn remove_polynomials(&mut self, to_remove: &BTreeSet<PolyID>) {
let replacements: BTreeMap<PolyID, PolyID> = [
let mut replacements: BTreeMap<PolyID, PolyID> = [
// We have to do it separately because we need to re-start the counter
// for each kind.
self.committed_polys_in_source_order(),
self.constant_polys_in_source_order(),
self.intermediate_polys_in_source_order(),
]
.map(|polys| {
polys
Expand Down Expand Up @@ -137,6 +151,13 @@ impl<T> Analyzed<T> {
.flatten()
.collect();

// We assume for now that intermediate columns are not removed.
for (poly, _) in self.intermediate_columns.values() {
let poly_id: PolyID = poly.into();
assert!(!to_remove.contains(&poly_id));
replacements.insert(poly_id, poly_id);
}

let mut names_to_remove: HashSet<String> = Default::default();
self.definitions.retain(|name, (poly, _def)| {
if matches!(poly.kind, SymbolKind::Poly(_))
Expand Down Expand Up @@ -235,6 +256,9 @@ impl<T> Analyzed<T> {
self.identities
.iter_mut()
.for_each(|i| i.post_visit_expressions_mut(f));
self.intermediate_columns
.values_mut()
.for_each(|(_sym, value)| value.post_visit_expressions_mut(f));
}

pub fn post_visit_expressions_in_definitions_mut<F>(&mut self, f: &mut F)
Expand Down
1 change: 1 addition & 0 deletions backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ pilopt = { path = "../pilopt" }
mktemp = "0.5.0"
test-log = "0.2.12"
env_logger = "0.10.0"
pretty_assertions = "1.4.0"
11 changes: 8 additions & 3 deletions backend/src/pilstark/json_exporter/expression_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@ pub fn compute_intermediate_expression_ids<T>(analyzed: &Analyzed<T>) -> HashMap
for item in &analyzed.source_order {
expression_counter += match item {
StatementIdentifier::Definition(name) => {
let poly = &analyzed.definitions[name].0;
if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) {
if let Some((poly, _)) = analyzed.definitions.get(name) {
assert!(poly.kind != SymbolKind::Poly(PolynomialType::Intermediate));
poly.expression_count()
} else if let Some((poly, _)) = analyzed.intermediate_columns.get(name) {
assert!(poly.kind == SymbolKind::Poly(PolynomialType::Intermediate));
ids.insert(poly.id, expression_counter as u64);
poly.expression_count()
} else {
unreachable!()
}
poly.expression_count()
}
StatementIdentifier::PublicDeclaration(name) => {
analyzed.public_declarations[name].expression_count()
Expand Down
46 changes: 31 additions & 15 deletions backend/src/pilstark/json_exporter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::cmp;
use std::collections::HashMap;

use ast::analyzed::{
self, Analyzed, BinaryOperator, Expression, FunctionValueDefinition, IdentityKind, PolyID,
PolynomialReference, PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator,
self, Analyzed, BinaryOperator, Expression, IdentityKind, PolyID, PolynomialReference,
PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator,
};
use starky::types::{
ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity,
Expand Down Expand Up @@ -46,18 +46,13 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> PIL {
for item in &analyzed.source_order {
match item {
StatementIdentifier::Definition(name) => {
if let (poly, Some(value)) = &analyzed.definitions[name] {
if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) {
if let FunctionValueDefinition::Expression(value) = value {
let expression_id = exporter.extract_expression(value, 1);
assert_eq!(
expression_id,
exporter.intermediate_poly_expression_ids[&poly.id] as usize
);
} else {
panic!("Expected single value");
}
}
if let Some((poly, value)) = analyzed.intermediate_columns.get(name) {
assert_eq!(poly.kind, SymbolKind::Poly(PolynomialType::Intermediate));
let expression_id = exporter.extract_expression(value, 1);
assert_eq!(
expression_id,
exporter.intermediate_poly_expression_ids[&poly.id] as usize
);
}
}
StatementIdentifier::PublicDeclaration(name) => {
Expand Down Expand Up @@ -187,7 +182,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
.filter_map(|(name, (symbol, _value))| {
let id = match symbol.kind {
SymbolKind::Poly(PolynomialType::Intermediate) => {
Some(self.intermediate_poly_expression_ids[&symbol.id])
panic!("Should be in intermediates")
}
SymbolKind::Poly(_) => Some(symbol.id),
SymbolKind::Other() | SymbolKind::Constant() => None,
Expand All @@ -204,6 +199,26 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
};
Some((name.clone(), out))
})
.chain(
self.analyzed
.intermediate_columns
.iter()
.map(|(name, (symbol, _))| {
assert_eq!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate));
let id = self.intermediate_poly_expression_ids[&symbol.id];

let out = Reference {
polType: None,
type_: symbol_kind_to_json_string(symbol.kind).to_string(),
id: id as usize,
polDeg: symbol.degree as usize,
isArray: symbol.is_array(),
elementType: None,
len: symbol.length.map(|l| l as usize),
};
(name.clone(), out)
}),
)
.collect::<HashMap<String, Reference>>()
}

Expand Down Expand Up @@ -368,6 +383,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
#[cfg(test)]
mod test {
use pil_analyzer::analyze;
use pretty_assertions::assert_eq;
use serde_json::Value as JsonValue;
use std::{fs, process::Command};
use test_log::test;
Expand Down
2 changes: 2 additions & 0 deletions executor/src/constant_evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ fn generate_values<T: FieldElement>(
.map(|i| {
Evaluator {
definitions: &analyzed.definitions,
intermediate_columns: &analyzed.intermediate_columns,
variables: &[i.into()],
function_cache: other_constants,
}
Expand All @@ -52,6 +53,7 @@ fn generate_values<T: FieldElement>(
FunctionValueDefinition::Array(values) => {
let evaluator = Evaluator {
definitions: &analyzed.definitions,
intermediate_columns: &analyzed.intermediate_columns,
variables: &[],
function_cache: other_constants,
};
Expand Down
24 changes: 22 additions & 2 deletions pil_analyzer/src/condenser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::HashMap;
use ast::{
analyzed::{
Analyzed, Expression, FunctionValueDefinition, Identity, PolynomialReference,
PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind,
PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind,
},
evaluate_binary_operation, evaluate_unary_operation,
parsed::{visitor::ExpressionVisitable, SelectedExpressions},
Expand All @@ -22,7 +22,7 @@ pub fn condense<T: FieldElement>(
source_order: Vec<StatementIdentifier>,
) -> Analyzed<T> {
let condenser = Condenser {
constants: compute_constants(&definitions),
constants: compute_constants(&definitions, &Default::default()),
symbols: definitions
.iter()
.map(|(name, (symbol, _))| (name.clone(), symbol.clone()))
Expand All @@ -34,6 +34,25 @@ pub fn condense<T: FieldElement>(
.map(|identity| condenser.condense_identity(identity))
.collect();

// Extract intermediate columns
let intermediate_columns: HashMap<_, _> = definitions
.iter()
.filter_map(|(name, (symbol, definition))| {
if matches!(symbol.kind, SymbolKind::Poly(PolynomialType::Intermediate)) {
let Some(FunctionValueDefinition::Expression(e)) = definition else {
panic!("Expected expression")
};
Some((
name.clone(),
(symbol.clone(), condenser.condense_expression(e.clone())),
))
} else {
None
}
})
.collect();
definitions.retain(|name, _| !intermediate_columns.contains_key(name));

definitions.values_mut().for_each(|(_, definition)| {
if let Some(def) = definition {
def.post_visit_expressions_mut(&mut |e| {
Expand All @@ -50,6 +69,7 @@ pub fn condense<T: FieldElement>(
Analyzed {
definitions,
public_declarations,
intermediate_columns,
identities,
source_order,
}
Expand Down
Loading

0 comments on commit 7721ec8

Please sign in to comment.