Skip to content

Commit

Permalink
prover type
Browse files Browse the repository at this point in the history
  • Loading branch information
ratankaliani committed May 30, 2024
1 parent b954a47 commit 52c9c94
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use strum_macros::EnumString;
pub struct ProverClient {
/// The underlying prover implementation.
pub prover: Box<dyn Prover>,
pub prover_type: ProverType,
}

/// The type of prover used by the [ProverClient].
Expand Down Expand Up @@ -96,12 +97,15 @@ impl ProverClient {
{
"mock" => Self {
prover: Box::new(MockProver::new()),
prover_type: ProverType::Mock,
},
"local" => Self {
prover: Box::new(LocalProver::new()),
prover_type: ProverType::Local,
},
"network" => Self {
prover: Box::new(NetworkProver::new()),
prover_type: ProverType::Network,
},
_ => panic!(
"invalid value for SP1_PROVER enviroment variable: expected 'local', 'mock', or 'remote'"
Expand All @@ -124,6 +128,7 @@ impl ProverClient {
pub fn mock() -> Self {
Self {
prover: Box::new(MockProver::new()),
prover_type: ProverType::Mock,
}
}

Expand All @@ -142,6 +147,7 @@ impl ProverClient {
pub fn local() -> Self {
Self {
prover: Box::new(LocalProver::new()),
prover_type: ProverType::Local,
}
}

Expand All @@ -159,30 +165,7 @@ impl ProverClient {
pub fn remote() -> Self {
Self {
prover: Box::new(NetworkProver::new()),
}
}

/// Returns the type of prover used by the [ProverClient].
///
/// ### Examples
///
/// ```no_run
/// use sp1_sdk::ProverClient;
///
/// let client = ProverClient::local();
/// let prover_type = client.prover_type();
/// assert_eq!(prover_type, ProverType::Local);
/// ```
pub fn prover_type(&self) -> ProverType {
let prover_type_id = (*self.prover).type_id();
if prover_type_id == TypeId::of::<LocalProver>() {
ProverType::Local
} else if prover_type_id == TypeId::of::<MockProver>() {
ProverType::Mock
} else if prover_type_id == TypeId::of::<NetworkProver>() {
ProverType::Network
} else {
panic!("invalid prover type")
prover_type: ProverType::Network,
}
}

Expand Down

0 comments on commit 52c9c94

Please sign in to comment.