From 1b29f6f76ccebecf30dcdfe4182582a466c71785 Mon Sep 17 00:00:00 2001 From: Guillaume Ballet <3272758+gballet@users.noreply.github.com> Date: Thu, 5 Oct 2023 12:12:32 +0200 Subject: [PATCH] poseidon gl rust --- compiler/benches/executor_benchmark.rs | 4 +- powdr_cli/src/main.rs | 49 +++- riscv/runtime/src/coprocessors.rs | 34 ++- riscv/src/compiler.rs | 106 +++---- riscv/src/coprocessors.rs | 272 ++++++++++++++++++ riscv/src/lib.rs | 16 +- riscv/tests/instructions.rs | 6 +- riscv/tests/riscv.rs | 36 ++- riscv/tests/riscv_data/poseidon.rs | 12 - .../riscv_data/poseidon_gl_via_coprocessor.rs | 68 +++++ 10 files changed, 494 insertions(+), 109 deletions(-) create mode 100644 riscv/src/coprocessors.rs delete mode 100644 riscv/tests/riscv_data/poseidon.rs create mode 100644 riscv/tests/riscv_data/poseidon_gl_via_coprocessor.rs diff --git a/compiler/benches/executor_benchmark.rs b/compiler/benches/executor_benchmark.rs index 68e8f485e..4caf72573 100644 --- a/compiler/benches/executor_benchmark.rs +++ b/compiler/benches/executor_benchmark.rs @@ -8,13 +8,15 @@ use mktemp::Temp; use number::{FieldElement, GoldilocksField}; use riscv::{compile_rust_crate_to_riscv_asm, compiler}; +use riscv::CoProcessors; + type T = GoldilocksField; fn get_pil() -> Analyzed { let tmp_dir = Temp::new_dir().unwrap(); let riscv_asm_files = compile_rust_crate_to_riscv_asm("../riscv/tests/riscv_data/keccak/Cargo.toml", &tmp_dir); - let contents = compiler::compile(riscv_asm_files); + let contents = compiler::compile(riscv_asm_files, &CoProcessors::base()); let parsed = parser::parse_asm::(None, &contents).unwrap(); let resolved = importer::resolve(None, parsed).unwrap(); let analyzed = analyze(resolved).unwrap(); diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 9c646c5bb..d2e5fb214 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -122,6 +122,10 @@ enum Commands { #[arg(short, long)] #[arg(value_parser = clap_enum_variants!(BackendType))] prove_with: Option, + + /// Comma-separated list of coprocessors. + #[arg(long)] + coprocessors: Option, }, /// Compiles riscv assembly to powdr assembly and then to PIL @@ -156,6 +160,10 @@ enum Commands { #[arg(short, long)] #[arg(value_parser = clap_enum_variants!(BackendType))] prove_with: Option, + + /// Comma-separated list of coprocessors. + #[arg(long)] + coprocessors: Option, }, Prove { @@ -285,13 +293,21 @@ fn run_command(command: Commands) { output_directory, force, prove_with, + coprocessors, } => { + let coprocessors = match coprocessors { + Some(list) => { + riscv::CoProcessors::try_from(list.split(',').collect::>()).unwrap() + } + None => riscv::CoProcessors::base(), + }; if let Err(errors) = call_with_field!(run_rust::( &file, split_inputs(&inputs), Path::new(&output_directory), force, - prove_with + prove_with, + coprocessors )) { eprintln!("Errors:"); for e in errors { @@ -306,6 +322,7 @@ fn run_command(command: Commands) { output_directory, force, prove_with, + coprocessors, } => { assert!(!files.is_empty()); let name = if files.len() == 1 { @@ -314,13 +331,20 @@ fn run_command(command: Commands) { Cow::Borrowed("output") }; + let coprocessors = match coprocessors { + Some(list) => { + riscv::CoProcessors::try_from(list.split(',').collect::>()).unwrap() + } + None => riscv::CoProcessors::base(), + }; if let Err(errors) = call_with_field!(run_riscv_asm::( &name, files.into_iter(), split_inputs(&inputs), Path::new(&output_directory), force, - prove_with + prove_with, + coprocessors )) { eprintln!("Errors:"); for e in errors { @@ -386,7 +410,7 @@ fn run_command(command: Commands) { } => { call_with_field!(setup::(size, dir, backend)); } - } + }; } fn setup(size: u64, dir: String, backend_type: BackendType) { @@ -403,15 +427,18 @@ fn write_backend_to_fs(be: &dyn Backend, output_dir: &Path) params_writer.flush().unwrap(); log::info!("Wrote params.bin."); } + fn run_rust( file_name: &str, inputs: Vec, output_dir: &Path, force_overwrite: bool, prove_with: Option, + coprocessors: riscv::CoProcessors, ) -> Result<(), Vec> { - let (asm_file_path, asm_contents) = compile_rust(file_name, output_dir, force_overwrite) - .ok_or_else(|| vec!["could not compile rust".to_string()])?; + let (asm_file_path, asm_contents) = + compile_rust(file_name, output_dir, force_overwrite, &coprocessors) + .ok_or_else(|| vec!["could not compile rust".to_string()])?; compile_asm_string( asm_file_path.to_str().unwrap(), @@ -431,10 +458,16 @@ fn run_riscv_asm( output_dir: &Path, force_overwrite: bool, prove_with: Option, + coprocessors: riscv::CoProcessors, ) -> Result<(), Vec> { - let (asm_file_path, asm_contents) = - compile_riscv_asm(original_file_name, file_names, output_dir, force_overwrite) - .ok_or_else(|| vec!["could not compile RISC-V assembly".to_string()])?; + let (asm_file_path, asm_contents) = compile_riscv_asm( + original_file_name, + file_names, + output_dir, + force_overwrite, + &coprocessors, + ) + .ok_or_else(|| vec!["could not compile RISC-V assembly".to_string()])?; compile_asm_string( asm_file_path.to_str().unwrap(), diff --git a/riscv/runtime/src/coprocessors.rs b/riscv/runtime/src/coprocessors.rs index f9ebc152f..9e0c7eced 100644 --- a/riscv/runtime/src/coprocessors.rs +++ b/riscv/runtime/src/coprocessors.rs @@ -1,12 +1,36 @@ // This is a dummy implementation of Poseidon hash, -// which will be replaced with a call to the poseidon -// coporocessor during compilation. +// which will be replaced with a call to the Poseidon +// coprocessor during compilation. // The function itself will be removed by the compiler // during the reachability analysis. extern "C" { - fn poseidon_coprocessor(a: u32, b: u32) -> u32; + fn poseidon_gl_coprocessor(data: *mut [u64; 12]); } -pub fn poseidon_hash(a: u32, b: u32) -> u32 { - unsafe { poseidon_coprocessor(a, b) } +const GOLDILOCKS: u64 = 0xffffffff00000001; + +/// Calls the low level Poseidon coprocessor in PIL, where +/// the last 4 elements are the "cap" +/// and the return value is placed in data[0:4]. +/// The safe version below also checks that each u64 element +/// is less than the Goldilocks field. +/// The unsafe version does not perform such checks. +pub fn poseidon_gl(mut data: [u64; 12]) -> [u64; 4] { + for &n in data.iter() { + assert!(n < GOLDILOCKS); + } + + unsafe { + poseidon_gl_coprocessor(&mut data as *mut [u64; 12]); + } + + [data[0], data[1], data[2], data[3]] +} + +pub fn poseidon_gl_unsafe(mut data: [u64; 12]) -> [u64; 4] { + unsafe { + poseidon_gl_coprocessor(&mut data as *mut [u64; 12]); + } + + [data[0], data[1], data[2], data[3]] } diff --git a/riscv/src/compiler.rs b/riscv/src/compiler.rs index 8688d3eb3..43d953130 100644 --- a/riscv/src/compiler.rs +++ b/riscv/src/compiler.rs @@ -16,6 +16,7 @@ use asm_utils::{ }; use itertools::Itertools; +use crate::coprocessors::*; use crate::disambiguator; use crate::parser::RiscParser; use crate::{Argument, Expression, Statement}; @@ -95,19 +96,15 @@ impl Architecture for RiscvArchitecture { } } -pub fn machine_decls() -> Vec<&'static str> { - vec!["use std::binary::Binary;", "use std::shift::Shift;"] -} - /// Compiles riscv assembly to a powdr assembly file. Adds required library routines. -pub fn compile(mut assemblies: BTreeMap) -> String { +pub fn compile(mut assemblies: BTreeMap, coprocessors: &CoProcessors) -> String { // stack grows towards zero let stack_start = 0x10000; // data grows away from zero let data_start = 0x10100; assert!(assemblies - .insert("__runtime".to_string(), runtime().to_string()) + .insert("__runtime".to_string(), runtime(coprocessors)) .is_none()); // TODO remove unreferenced files. @@ -133,7 +130,7 @@ pub fn compile(mut assemblies: BTreeMap) -> String { // Remove the riscv asm stub function, which is used // for compilation, and will not be called. - statements = replace_coprocessor_stubs(statements).collect::>(); + statements = replace_coprocessor_stubs(statements, coprocessors).collect::>(); // Sort the objects according to the order of the names in object_order. // With the single exception: If there is large object, put that at the end. @@ -190,14 +187,14 @@ pub fn compile(mut assemblies: BTreeMap) -> String { ); riscv_machine( - &machine_decls(), - &preamble(), - &[("binary", "Binary"), ("shift", "Shift")], + &coprocessors.machine_imports(), + &preamble(coprocessors), + &coprocessors.declarations(), file_ids .into_iter() .map(|(id, dir, file)| format!("debug file {id} {} {};", quote(&dir), quote(&file))) .chain(["call __data_init;".to_string()]) - .chain(call_every_submachine()) + .chain(call_every_submachine(coprocessors)) .chain([ format!("// Set stack pointer\nx2 <=X= {stack_start};"), "call __runtime_start;".to_string(), @@ -206,7 +203,7 @@ pub fn compile(mut assemblies: BTreeMap) -> String { .chain( substitute_symbols_with_values(statements, &data_positions) .into_iter() - .flat_map(process_statement), + .flat_map(|v| process_statement(v, coprocessors)), ) .chain(["// This is the data initialization routine.\n__data_init::".to_string()]) .chain(data_code) @@ -318,31 +315,17 @@ where .filter_map(|(filter, statement)| (!filter).then_some(statement)) } -fn replace_coprocessor_stubs( - statements: impl IntoIterator, -) -> impl Iterator { - let stub_names: Vec<&str> = COPROCESSOR_SUBSTITUTIONS - .iter() - .map(|(name, _)| *name) - .collect(); +fn replace_coprocessor_stubs<'a>( + statements: impl IntoIterator + 'a, + coprocessors: &'a CoProcessors, +) -> impl Iterator + 'a { + let stub_names: Vec<&'a str> = coprocessors.runtime_names(); remove_matching_and_next(statements.into_iter(), move |statement| -> bool { matches!(&statement, Statement::Label(label) if stub_names.contains(&label.as_str())) }) } -fn call_every_submachine() -> Vec { - // TODO This is a hacky snippet to ensure that every submachine in the RISCV machine - // is called at least once. This is needed for witgen until it can do default blocks - // automatically. - // https://github.com/powdr-labs/powdr/issues/548 - vec![ - "x10 <== and(x10, x10);".to_string(), - "x10 <== shl(x10, x10);".to_string(), - "x10 <=X= 0;".to_string(), - ] -} - fn substitute_symbols_with_values( mut statements: Vec, data_positions: &BTreeMap, @@ -435,7 +418,7 @@ machine Main {{ ) } -fn preamble() -> String { +fn preamble(coprocessors: &CoProcessors) -> String { r#" degree 262144; reg pc[@pc]; @@ -443,12 +426,17 @@ fn preamble() -> String { reg Y[<=]; reg Z[<=]; reg W[<=]; +"# + .to_string() + + &coprocessors.registers() + + &r#" reg tmp1; reg tmp2; reg tmp3; reg lr_sc_reservation; "# - .to_string() + .to_owned() + .to_string() + &(0..32) .map(|i| format!("\t\treg x{i};\n")) .collect::>() @@ -552,28 +540,8 @@ fn preamble() -> String { instr is_not_equal_zero X -> Y { Y = 1 - XIsZero } // ================= coprocessor substitution instructions ================= - - instr poseidon Y, Z -> X { - // Dummy code, to be replaced with actual poseidon code. - X = 0 - } - - // ================= binary/bitwise instructions ================= - - instr and Y, Z -> X = binary.and - - instr or Y, Z -> X = binary.or - - instr xor Y, Z -> X = binary.xor - - // ================= shift instructions ================= - - instr shl Y, Z -> X = shift.shl - - instr shr Y, Z -> X = shift.shr - - // ================== wrapping instructions ============== - +"# + &coprocessors.instructions() + + r#" // Wraps a value in Y to 32 bits. // Requires 0 <= Y < 2**33 instr wrap Y -> X { Y = X + wrap_bit * 2**32, X = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000 } @@ -679,7 +647,7 @@ fn preamble() -> String { "# } -fn runtime() -> &'static str { +fn runtime(coprocessors: &CoProcessors) -> String { r#" .globl __udivdi3@plt .globl __udivdi3 @@ -727,14 +695,12 @@ fn runtime() -> &'static str { .globl __rust_alloc_error_handler .set __rust_alloc_error_handler, __rg_oom - -.globl poseidon_coprocessor -poseidon_coprocessor: - ret "# + .to_owned() + + &coprocessors.runtime() } -fn process_statement(s: Statement) -> Vec { +fn process_statement(s: Statement, coprocessors: &CoProcessors) -> Vec { match &s { Statement::Label(l) => vec![format!("{}::", escape_label(l))], Statement::Directive(directive, args) => match (directive.as_str(), &args[..]) { @@ -754,7 +720,7 @@ fn process_statement(s: Statement) -> Vec { args.iter().format(", ") ), }, - Statement::Instruction(instr, args) => process_instruction(instr, args) + Statement::Instruction(instr, args) => process_instruction(instr, args, coprocessors) .into_iter() .map(|s| " ".to_string() + &s) .collect(), @@ -847,17 +813,15 @@ fn only_if_no_write_to_zero_vec(statements: Vec, reg: Register) -> Vec Option { - COPROCESSOR_SUBSTITUTIONS +fn try_coprocessor_substitution(label: &str, coprocessors: &CoProcessors) -> Option { + coprocessors + .substitutions() .iter() .find(|(l, _)| *l == label) - .map(|&(_, subst)| subst.to_string()) + .map(|(_, subst)| subst.to_string()) } -fn process_instruction(instr: &str, args: &[Argument]) -> Vec { +fn process_instruction(instr: &str, args: &[Argument], coprocessors: &CoProcessors) -> Vec { match instr { // load/store registers "li" => { @@ -1175,7 +1139,9 @@ fn process_instruction(instr: &str, args: &[Argument]) -> Vec { assert_eq!(args.len(), 1); let label = &args[0]; let replacement = match label { - Argument::Expression(Expression::Symbol(l)) => try_coprocessor_substitution(l), + Argument::Expression(Expression::Symbol(l)) => { + try_coprocessor_substitution(l, coprocessors) + } _ => None, }; match (replacement, instr) { diff --git a/riscv/src/coprocessors.rs b/riscv/src/coprocessors.rs new file mode 100644 index 000000000..a2d796f6d --- /dev/null +++ b/riscv/src/coprocessors.rs @@ -0,0 +1,272 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + convert::TryFrom, +}; + +type RuntimeFunctionImpl = (&'static str, fn() -> String); + +#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)] +struct CoProcessor { + name: &'static str, + ty: &'static str, + import: &'static str, + instructions: &'static str, + runtime_function_impl: Option, +} + +static BINARY_COPROCESSOR: CoProcessor = CoProcessor { + name: "binary", + ty: "Binary", + import: "use std::binary::Binary;", + instructions: r#" + // ================= binary/bitwise instructions ================= + instr and Y, Z -> X = binary.and + instr or Y, Z -> X = binary.or + instr xor Y, Z -> X = binary.xor + + "#, + runtime_function_impl: None, +}; + +static SHIFT_COPROCESSOR: CoProcessor = CoProcessor { + name: "shift", + ty: "Shift", + import: "use std::shift::Shift;", + instructions: r#" + // ================= shift instructions ================= + instr shl Y, Z -> X = shift.shl + instr shr Y, Z -> X = shift.shr + + "#, + runtime_function_impl: None, +}; + +static SPLIT_GL_COPROCESSOR: CoProcessor = CoProcessor { + name: "split_gl", + ty: "SplitGL", + import: "use std::split::split_gl::SplitGL;", + instructions: r#" +// ================== wrapping instructions ============== +instr split_gl Z -> X, Y = split_gl.split + + "#, + runtime_function_impl: None, +}; + +static POSEIDON_GL_COPROCESSOR: CoProcessor = CoProcessor { + name: "poseidon_gl", + ty: "PoseidonGL", + import: "use std::hash::poseidon_gl::PoseidonGL;", + instructions: r#" +// ================== hashing instructions ============== +instr poseidon_gl A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11 -> X, Y, Z, W = poseidon_gl.poseidon_permutation + +"#, + runtime_function_impl: Some(("poseidon_gl_coprocessor", poseidon_gl_call)), +}; + +static ALL_COPROCESSORS: [(&str, &CoProcessor); 4] = [ + (BINARY_COPROCESSOR.name, &BINARY_COPROCESSOR), + (SHIFT_COPROCESSOR.name, &SHIFT_COPROCESSOR), + (SPLIT_GL_COPROCESSOR.name, &SPLIT_GL_COPROCESSOR), + (POSEIDON_GL_COPROCESSOR.name, &POSEIDON_GL_COPROCESSOR), +]; + +/// Defines which coprocessors should be used by the RISCV machine. +/// It is important to not add unused coprocessors since they may +/// lead to many extra columns in PIL. +#[derive(Default)] +pub struct CoProcessors { + coprocessors: BTreeMap<&'static str, &'static CoProcessor>, +} + +impl TryFrom> for CoProcessors { + type Error = String; + + fn try_from(list: Vec<&str>) -> Result { + let items: BTreeSet<&str> = list.into_iter().collect(); + + if !items.iter().all(|co_processor| { + ALL_COPROCESSORS + .iter() + .any(|(name, _)| co_processor == name) + }) { + return Err("Invalid co-processor specified.".to_string()); + } + + Ok(Self { + coprocessors: ALL_COPROCESSORS + .iter() + .filter_map(|(name, co_processor)| { + if items.contains(name) { + Some((*name, *co_processor)) + } else { + None + } + }) + .collect(), + }) + } +} + +impl CoProcessors { + /// The base version only adds the commonly used bitwise and shift operations. + pub fn base() -> CoProcessors { + Self { + coprocessors: BTreeMap::from([ + (BINARY_COPROCESSOR.name, &BINARY_COPROCESSOR), + (SHIFT_COPROCESSOR.name, &SHIFT_COPROCESSOR), + ]), + } + } + + /// Poseidon also uses the Split machine. + pub fn with_poseidon(mut self) -> Self { + self.coprocessors + .insert(SPLIT_GL_COPROCESSOR.name, &SPLIT_GL_COPROCESSOR); + self.coprocessors + .insert(POSEIDON_GL_COPROCESSOR.name, &POSEIDON_GL_COPROCESSOR); + self + } + + pub fn has(&self, key: &str) -> bool { + self.coprocessors.contains_key(key) + } + + pub fn declarations(&self) -> Vec<(&'static str, &'static str)> { + self.coprocessors.values().map(|c| (c.name, c.ty)).collect() + } + + pub fn machine_imports(&self) -> Vec<&'static str> { + self.coprocessors.values().map(|c| c.import).collect() + } + + pub fn instructions(&self) -> String { + self.coprocessors + .values() + .map(|c| c.instructions) + .collect::>() + .join("") + } + + pub fn runtime_names(&self) -> Vec<&str> { + self.coprocessors + .values() + .filter_map(|c| c.runtime_function_impl) + .map(|f| f.0) + .collect() + } + + pub fn runtime(&self) -> String { + self.runtime_names() + .iter() + .map(|f| { + format!( + r#" + .globl {} + {}: + ret + "#, + f, f + ) + }) + .collect::>() + .join("\n") + } + + pub fn substitutions(&self) -> Vec<(&'static str, String)> { + self.coprocessors + .values() + .filter_map(|c| c.runtime_function_impl) + .map(|f| (f.0, f.1())) + .collect() + } + + pub fn registers(&self) -> String { + // Poseidon has 12 inputs and 4 outputs. + // The base RISCV machine has 4 assignment registers. + // Therefore we need to add 12 assignment registers when using Poseidon. + // Moreover, we also need 12 extra general purpose registers to store the + // input values. + + if !self.coprocessors.contains_key(POSEIDON_GL_COPROCESSOR.name) { + return String::new(); + } + + let a_regs: Vec = (0..12).map(|i| format!("reg A{}[<=];", i)).collect(); + let p_regs: Vec = (0..12).map(|i| format!("reg P{};", i)).collect(); + + [a_regs, p_regs].concat().join("\n") + } +} + +fn poseidon_gl_call() -> String { + let decoding = |i| { + format!( + r#" + addr <=X= {} + x10; + P{i} <== mload(); + addr <=X= {} + x10; + tmp1 <== mload(); + P{i} <=X= P{i} + tmp1 * 2**32; + "#, + i * 8, + i * 8 + 4 + ) + }; + + let encoding = |i| { + format!( + r#" + tmp1, tmp2 <== split_gl(P{i}); + addr <=X= {} + x10; + mstore tmp1; + addr <=X= {} + x10; + mstore tmp2; + "#, + i * 8, + i * 8 + 4 + ) + }; + + let call = "P0, P1, P2, P3 <== poseidon_gl(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11);"; + + (0..12) + .map(decoding) + .chain(std::iter::once(call.to_string())) + .chain((0..4).map(encoding)) + .collect() +} + +// This could also potentially go in the impl of CoProcessors, +// but I purposefully left it outside because it should be removed eventually. +pub fn call_every_submachine(coprocessors: &CoProcessors) -> Vec { + // TODO This is a hacky snippet to ensure that every submachine in the RISCV machine + // is called at least once. This is needed for witgen until it can do default blocks + // automatically. + // https://github.com/powdr-labs/powdr/issues/548 + let mut calls = vec![]; + if coprocessors.has(BINARY_COPROCESSOR.name) { + calls.push("x10 <== and(x10, x10);".to_string()); + } + if coprocessors.has(SHIFT_COPROCESSOR.name) { + calls.push("x10 <== shl(x10, x10);".to_string()); + } + if coprocessors.has(SPLIT_GL_COPROCESSOR.name) { + calls.push("x10, x11 <== split_gl(x10);".to_string()); + } + if coprocessors.has(POSEIDON_GL_COPROCESSOR.name) { + calls.extend(vec![ + "P0, P1, P2, P3 <== poseidon_gl(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11);" + .to_string(), + "P0 <=X= 0;".to_string(), + "P1 <=X= 0;".to_string(), + "P2 <=X= 0;".to_string(), + "P3 <=X= 0;".to_string(), + ]); + } + + calls.extend(vec!["x10 <=X= 0;".to_string(), "x11 <=X= 0;".to_string()]); + + calls +} diff --git a/riscv/src/lib.rs b/riscv/src/lib.rs index b1670718d..0a6dea58d 100644 --- a/riscv/src/lib.rs +++ b/riscv/src/lib.rs @@ -13,8 +13,10 @@ use serde_json::Value as JsonValue; use std::fs; use crate::compiler::{FunctionKind, Register}; +pub use crate::coprocessors::CoProcessors; pub mod compiler; +mod coprocessors; mod disambiguator; pub mod parser; @@ -28,6 +30,7 @@ pub fn compile_rust( file_name: &str, output_dir: &Path, force_overwrite: bool, + coprocessors: &CoProcessors, ) -> Option<(PathBuf, String)> { let riscv_asm = if file_name.ends_with("Cargo.toml") { compile_rust_crate_to_riscv_asm(file_name, output_dir) @@ -56,7 +59,13 @@ pub fn compile_rust( log::info!("Wrote {}", riscv_asm_file_name.to_str().unwrap()); } - compile_riscv_asm_bundle(file_name, riscv_asm, output_dir, force_overwrite) + compile_riscv_asm_bundle( + file_name, + riscv_asm, + output_dir, + force_overwrite, + coprocessors, + ) } pub fn compile_riscv_asm_bundle( @@ -64,6 +73,7 @@ pub fn compile_riscv_asm_bundle( riscv_asm_files: BTreeMap, output_dir: &Path, force_overwrite: bool, + coprocessors: &CoProcessors, ) -> Option<(PathBuf, String)> { let powdr_asm_file_name = output_dir.join(format!( "{}.asm", @@ -81,7 +91,7 @@ pub fn compile_riscv_asm_bundle( return None; } - let powdr_asm = compiler::compile(riscv_asm_files); + let powdr_asm = compiler::compile(riscv_asm_files, coprocessors); fs::write(powdr_asm_file_name.clone(), &powdr_asm).unwrap(); log::info!("Wrote {}", powdr_asm_file_name.to_str().unwrap()); @@ -96,6 +106,7 @@ pub fn compile_riscv_asm( file_names: impl Iterator, output_dir: &Path, force_overwrite: bool, + coprocessors: &CoProcessors, ) -> Option<(PathBuf, String)> { compile_riscv_asm_bundle( original_file_name, @@ -107,6 +118,7 @@ pub fn compile_riscv_asm( .collect(), output_dir, force_overwrite, + coprocessors, ) } diff --git a/riscv/tests/instructions.rs b/riscv/tests/instructions.rs index db443245d..9248d6c17 100644 --- a/riscv/tests/instructions.rs +++ b/riscv/tests/instructions.rs @@ -2,11 +2,15 @@ mod instruction_tests { use compiler::verify_asm_string; use number::GoldilocksField; use riscv::compiler::compile; + use riscv::CoProcessors; use test_log::test; fn run_instruction_test(assembly: &str, name: &str) { // TODO Should we create one powdr-asm from all tests or keep them separate? - let powdr_asm = compile([(name.to_string(), assembly.to_string())].into()); + let powdr_asm = compile( + [(name.to_string(), assembly.to_string())].into(), + &CoProcessors::base(), + ); verify_asm_string::(&format!("{name}.asm"), &powdr_asm, vec![]); } diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 5df9ae0fb..cdd5ae69e 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -3,11 +3,20 @@ use mktemp::Temp; use number::GoldilocksField; use test_log::test; +use riscv::CoProcessors; + #[test] #[ignore = "Too slow"] fn test_trivial() { let case = "trivial.rs"; - verify_file(case, vec![]); + verify_file(case, vec![], &CoProcessors::base()) +} + +#[test] +#[ignore = "Too slow"] +fn test_poseidon_gl() { + let case = "poseidon_gl_via_coprocessor.rs"; + verify_file(case, vec![], &CoProcessors::base().with_poseidon()); } #[test] @@ -17,6 +26,7 @@ fn test_sum() { verify_file( case, [16, 4, 1, 2, 8, 5].iter().map(|&x| x.into()).collect(), + &CoProcessors::base(), ); } @@ -24,7 +34,11 @@ fn test_sum() { #[ignore = "Too slow"] fn test_byte_access() { let case = "byte_access.rs"; - verify_file(case, [0, 104, 707].iter().map(|&x| x.into()).collect()); + verify_file( + case, + [0, 104, 707].iter().map(|&x| x.into()).collect(), + &CoProcessors::base(), + ); } #[test] @@ -49,6 +63,7 @@ fn test_double_word() { .iter() .map(|&x| x.into()) .collect(), + &CoProcessors::base(), ); } @@ -56,14 +71,14 @@ fn test_double_word() { #[ignore = "Too slow"] fn test_memfuncs() { let case = "memfuncs"; - verify_crate(case, vec![]); + verify_crate(case, vec![], &CoProcessors::base()); } #[test] #[ignore = "Too slow"] fn test_keccak() { let case = "keccak"; - verify_crate(case, vec![]); + verify_crate(case, vec![], &CoProcessors::base()); } #[test] @@ -76,6 +91,7 @@ fn test_vec_median() { .into_iter() .map(|x| x.into()) .collect(), + &CoProcessors::base(), ); } @@ -83,7 +99,7 @@ fn test_vec_median() { #[ignore = "Too slow"] fn test_password() { let case = "password_checker"; - verify_crate(case, vec![]); + verify_crate(case, vec![], &CoProcessors::base()); } // TODO: uncomment this when we properly support revm, so we don't break nightly @@ -101,25 +117,25 @@ fn test_evm() { #[should_panic(expected = "Witness generation failed.")] fn test_print() { let case = "print.rs"; - verify_file(case, vec![]); + verify_file(case, vec![], &CoProcessors::base()); } -fn verify_file(case: &str, inputs: Vec) { +fn verify_file(case: &str, inputs: Vec, coprocessors: &CoProcessors) { let temp_dir = Temp::new_dir().unwrap(); let riscv_asm = riscv::compile_rust_to_riscv_asm(&format!("tests/riscv_data/{case}"), &temp_dir); - let powdr_asm = riscv::compiler::compile(riscv_asm); + let powdr_asm = riscv::compiler::compile(riscv_asm, coprocessors); verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } -fn verify_crate(case: &str, inputs: Vec) { +fn verify_crate(case: &str, inputs: Vec, coprocessors: &CoProcessors) { let temp_dir = Temp::new_dir().unwrap(); let riscv_asm = riscv::compile_rust_crate_to_riscv_asm( &format!("tests/riscv_data/{case}/Cargo.toml"), &temp_dir, ); - let powdr_asm = riscv::compiler::compile(riscv_asm); + let powdr_asm = riscv::compiler::compile(riscv_asm, coprocessors); verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } diff --git a/riscv/tests/riscv_data/poseidon.rs b/riscv/tests/riscv_data/poseidon.rs deleted file mode 100644 index b142936b2..000000000 --- a/riscv/tests/riscv_data/poseidon.rs +++ /dev/null @@ -1,12 +0,0 @@ -#![no_std] - -use runtime::coprocessors::poseidon_hash; - -#[no_mangle] -pub fn main() { - let h = poseidon_hash(1, 2); - // This is the value returned by the coprocessor stub, - // this needs to be updated when the final version is - // merged. - assert_eq!(h, 0); -} \ No newline at end of file diff --git a/riscv/tests/riscv_data/poseidon_gl_via_coprocessor.rs b/riscv/tests/riscv_data/poseidon_gl_via_coprocessor.rs new file mode 100644 index 000000000..a8f652de9 --- /dev/null +++ b/riscv/tests/riscv_data/poseidon_gl_via_coprocessor.rs @@ -0,0 +1,68 @@ +#![no_std] + +use runtime::coprocessors::{poseidon_gl, poseidon_gl_unsafe}; + +#[no_mangle] +fn main() { + let i: [u64; 12] = [0; 12]; + let h = poseidon_gl(i); + assert_eq!(h[0], 4330397376401421145); + assert_eq!(h[1], 14124799381142128323); + assert_eq!(h[2], 8742572140681234676); + assert_eq!(h[3], 14345658006221440202); + + let i: [u64; 12] = [1; 12]; + let h = poseidon_gl(i); + assert_eq!(h[0], 16428316519797902711); + assert_eq!(h[1], 13351830238340666928); + assert_eq!(h[2], 682362844289978626); + assert_eq!(h[3], 12150588177266359240); + + let minus_one = 0xffffffff00000001 - 1; + let i: [u64; 12] = [minus_one; 12]; + let h = poseidon_gl(i); + assert_eq!(h[0], 13691089994624172887); + assert_eq!(h[1], 15662102337790434313); + assert_eq!(h[2], 14940024623104903507); + assert_eq!(h[3], 10772674582659927682); + + let i: [u64; 12] = [ + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 18446744069414584321, + 0, + 0, + 0, + 0, + ]; + let h = poseidon_gl_unsafe(i); + assert_eq!(h[0], 4330397376401421145); + assert_eq!(h[1], 14124799381142128323); + assert_eq!(h[2], 8742572140681234676); + assert_eq!(h[3], 14345658006221440202); + + let i: [u64; 12] = [ + 923978, + 235763497586, + 9827635653498, + 112870, + 289273673480943876, + 230295874986745876, + 6254867324987, + 2087, + 0, + 0, + 0, + 0, + ]; + let h = poseidon_gl(i); + assert_eq!(h[0], 1892171027578617759); + assert_eq!(h[1], 984732815927439256); + assert_eq!(h[2], 7866041765487844082); + assert_eq!(h[3], 8161503938059336191); +}