diff --git a/host/src/cli.rs b/host/src/cli.rs index ba1373c8..0b332bcb 100644 --- a/host/src/cli.rs +++ b/host/src/cli.rs @@ -177,6 +177,10 @@ pub struct RunArgs { #[clap(short, long, default_value_t = false)] /// Whether to profile the zkVM execution pub profile: bool, + + /// Optionally export cycles from executing in csv format to the file path. + #[clap(long, require_equals = true)] + pub export_csv: Option, } impl Tag for RunArgs { diff --git a/host/src/operations/build.rs b/host/src/operations/build.rs index 6d470a56..4983dc9c 100644 --- a/host/src/operations/build.rs +++ b/host/src/operations/build.rs @@ -17,10 +17,13 @@ use std::fmt::Debug; use anyhow::Context; use ethers_core::types::Transaction as EthersTransaction; use log::{info, warn}; -use risc0_zkvm::{compute_image_id, Receipt}; +use risc0_zkvm::{compute_image_id, Receipt, Session}; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use tokio::sync::Semaphore; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; +use tokio::sync::{mpsc, Semaphore}; +use tokio::task::JoinSet; use zeth_lib::{ builder::BlockBuilderStrategy, consts::ChainSpec, @@ -32,7 +35,7 @@ use zeth_lib::{ const MAX_CONCURRENT_REQUESTS: usize = 5; use crate::{ - cli::{BuildArgs, Cli}, + cli::{BuildArgs, Cli, RunArgs}, operations::{execute, maybe_prove, verify_bonsai_receipt}, }; @@ -91,12 +94,21 @@ where Ok((input, output)) } +// TODO use this +enum ExecuteBlockResult { + Run(Session), + Prove((String, Receipt)), +} + /// Build a single block using the specified strategy. async fn execute_block( input: BlockBuildInput, output: BlockBuildOutput, cli: Arc, guest_elf: &'static [u8], + session_channel: Option>, + // TODO hack for expedience, can pass this cleaner by updating the return type to an enum (and remove above) + block_num: u64, ) -> anyhow::Result> where N::TxEssence: 'static + Send + TryFrom + Serialize + Deserialize<'static>, @@ -106,7 +118,7 @@ where let result = match &*cli { Cli::Build(..) => None, Cli::Run(run_args) => { - execute( + let session = execute( &input, run_args.execution_po2, run_args.profile, @@ -114,6 +126,10 @@ where &compressed_output, &cli.execution_tag(), ); + + if let Some(sender) = session_channel { + sender.send((block_num, session.user_cycles, session.total_cycles))?; + } None } Cli::Prove(..) => { @@ -154,7 +170,32 @@ where let build_args = cli.build_args().clone(); let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_REQUESTS)); - let mut join_handles = Vec::new(); + let mut join_handles = JoinSet::new(); + + let (session_sender, writer_handle) = if let Cli::Run(RunArgs { + export_csv: Some(csv_path), + .. + }) = &*cli + { + // TODO Session type isn't sync friendly, u64s are block num, user cycles, total cycles + let (tx, mut rx) = mpsc::unbounded_channel::<(u64, u64, u64)>(); + let csv_path = csv_path.clone(); + let join_handle = tokio::spawn(async move { + let mut file = File::create(csv_path).await.unwrap(); + while let Some((block_num, user_cycles, total_cycles)) = rx.recv().await { + // Write csv record. + file.write_all( + // TODO remove leading ' -- just used for google sheets + format!("'{},{},{}\n", block_num, user_cycles, total_cycles).as_bytes(), + ) + .await + .unwrap(); + } + }); + (Some(tx), Some(join_handle)) + } else { + (None, None) + }; let block_num = build_args.block_number; // TODO semantics are a bit mixed with block count (was OP specific) @@ -166,9 +207,10 @@ where let rpc_url = rpc_url.clone(); let cli = cli.clone(); let chain_spec = chain_spec.clone(); + let session_sender = session_sender.clone(); // Spawn blocking for - join_handles.push(tokio::spawn(async move { + join_handles.spawn(async move { // Acquire permit before sending request. let _permit = semaphore.acquire().await.unwrap(); @@ -180,17 +222,38 @@ where // TODO this could be separated into a separate task, to make sure Bonsai also // doesn't get throttled, for now just going quick path of dropping permit after // preflight. - let result = execute_block::(input, output, cli, guest_elf).await; + let result = + execute_block::(input, output, cli, guest_elf, session_sender, num).await; result - })); + }); } + drop(session_sender); + // Collect responses from tasks. let mut responses = Vec::new(); - for jh in join_handles { - let response = jh.await?; - responses.push(response?); + // TODO hacky, should be one path if possible + if let Some(mut writer_handle) = writer_handle { + loop { + tokio::select! { + // Cancellation safety: `join_next` is cancel safe + Some(val) = join_handles.join_next() => { + responses.push(val??); + continue; + } + // Cancellation safety: &mut JoinHandle is cancel safe + val = &mut writer_handle => { + val?; + break; + } + else => { break } + } + } + } else { + while let Some(val) = join_handles.join_next().await { + responses.push(val??); + } } Ok(responses) diff --git a/host/src/operations/mod.rs b/host/src/operations/mod.rs index b0d9b7da..12b2b691 100644 --- a/host/src/operations/mod.rs +++ b/host/src/operations/mod.rs @@ -24,7 +24,7 @@ use risc0_zkvm::{ compute_image_id, serde::to_vec, sha::{Digest, Digestible}, - Assumption, ExecutorEnv, ExecutorImpl, Receipt, Segment, SegmentRef, + Assumption, ExecutorEnv, ExecutorImpl, Receipt, Segment, SegmentRef, Session, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use zeth_primitives::keccak::keccak; @@ -353,7 +353,7 @@ pub fn execute( elf: &[u8], expected_output: &O, profile_reference: &String, -) { +) -> Session { debug!( "Running in executor with segment_limit_po2 = {:?}", segment_limit_po2 @@ -390,7 +390,7 @@ pub fn execute( session.segments.len() * (1 << segment_limit_po2) ); // verify output - let journal = session.journal.unwrap(); + let journal = session.journal.as_ref().unwrap(); let output_guest: O = journal.decode().expect("Could not decode journal"); if expected_output == &output_guest { info!("Executor succeeded"); @@ -400,4 +400,6 @@ pub fn execute( output_guest, expected_output, ); } + + session }