Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add panic catch on operator calling FFI #1196

Merged
merged 11 commits into from
Oct 18, 2024
3 changes: 2 additions & 1 deletion operator/merkle_tree/lib/merkle_tree.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);
int32_t verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);
25 changes: 20 additions & 5 deletions operator/merkle_tree/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use aligned_sdk::core::types::{
use lambdaworks_crypto::merkle_tree::merkle::MerkleTree;
use log::error;

#[no_mangle]
pub extern "C" fn verify_merkle_tree_batch_ffi(
fn inner_verify_merkle_tree_batch_ffi(
batch_ptr: *const u8,
batch_len: usize,
merkle_root: &[u8; 32],
Expand Down Expand Up @@ -53,6 +52,22 @@ pub extern "C" fn verify_merkle_tree_batch_ffi(
computed_batch_merkle_tree.root == *merkle_root
}

#[no_mangle]
pub extern "C" fn verify_merkle_tree_batch_ffi(
batch_ptr: *const u8,
batch_len: usize,
merkle_root: &[u8; 32],
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_merkle_tree_batch_ffi(batch_ptr, batch_len, merkle_root)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -75,7 +90,7 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, true);
assert_eq!(result, 1);
}

#[test]
Expand All @@ -92,7 +107,7 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, false);
assert_eq!(result, 0);
}

#[test]
Expand All @@ -109,6 +124,6 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, false);
assert_eq!(result, 0);
}
}
28 changes: 25 additions & 3 deletions operator/merkle_tree/merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,35 @@ package merkle_tree
*/
import "C"
import "unsafe"
import "fmt"

func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) bool {
func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil
if len(batchBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying merkle tree batch: %s", rec)
}
}()

batchPtr := (*C.uchar)(unsafe.Pointer(&batchBuffer[0]))
merkleRootPtr := (*C.uchar)(unsafe.Pointer(&merkleRootBuffer[0]))
return (bool)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))

r := (C.int32_t)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))

if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying merkle tree batch")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
3 changes: 2 additions & 1 deletion operator/merkle_tree/merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func TestVerifyMerkleTreeBatch(t *testing.T) {
var merkleRoot [32]byte
copy(merkleRoot[:], merkle_root)

if !VerifyMerkleTreeBatch(batchByteValue, merkleRoot) {
verified, err := VerifyMerkleTreeBatch(batchByteValue, merkleRoot)
if err != nil || !verified {
t.Errorf("Batch did not verify Merkle Root")
}

Expand Down
18 changes: 14 additions & 4 deletions operator/pkg/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,13 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
results <- verificationResult

case common.SP1:
verificationResult := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
o.Logger.Infof("SP1 proof verification result: %t", verificationResult)
results <- verificationResult
verificationResult, err := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
o.handleVerificationResult(results, verificationResult, err, "SP1 proof verification")

case common.Risc0:
verificationResult := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
verificationResult, err := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
verificationData.VmProgramCode, verificationData.PubInput)
o.handleVerificationResult(results, verificationResult, err, "RiscZero proof verification")

o.Logger.Infof("Risc0 proof verification result: %t", verificationResult)
results <- verificationResult
Expand All @@ -512,6 +512,16 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
}
}

func (o *Operator) handleVerificationResult(results chan bool, isVerified bool, err error, name string) {
if err != nil {
o.Logger.Errorf("%v failed %v", name, err)
results <- false
} else {
o.Logger.Infof("%v result: %t", name, isVerified)
results <- isVerified
}
}

// VerifyPlonkProofBLS12_381 verifies a PLONK proof using BLS12-381 curve.
func (o *Operator) verifyPlonkProofBLS12_381(proofBytes []byte, pubInputBytes []byte, verificationKeyBytes []byte) bool {
return o.verifyPlonkProof(proofBytes, pubInputBytes, verificationKeyBytes, ecc.BLS12_381)
Expand Down
6 changes: 3 additions & 3 deletions operator/pkg/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ func (o *Operator) getBatchFromDataService(ctx context.Context, batchURL string,

// Checks if downloaded merkle root is the same as the expected one
o.Logger.Infof("Verifying batch merkle tree...")
merkle_root_check := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
if !merkle_root_check {
return nil, fmt.Errorf("merkle root check failed")
merkle_root_check, err := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
if err != nil || !merkle_root_check {
return nil, fmt.Errorf("Error while verifying merkle tree batch")
}
o.Logger.Infof("Batch merkle tree verified")

Expand Down
2 changes: 1 addition & 1 deletion operator/risc_zero/lib/risc_zero.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);
int32_t verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);
35 changes: 30 additions & 5 deletions operator/risc_zero/lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use log::error;
use risc0_zkvm::{InnerReceipt, Receipt};

#[no_mangle]
pub extern "C" fn verify_risc_zero_receipt_ffi(
fn inner_verify_risc_zero_receipt_ffi(
inner_receipt_bytes: *const u8,
inner_receipt_len: u32,
image_id: *const u8,
Expand Down Expand Up @@ -43,6 +42,32 @@ pub extern "C" fn verify_risc_zero_receipt_ffi(
false
}

#[no_mangle]
pub extern "C" fn verify_risc_zero_receipt_ffi(
inner_receipt_bytes: *const u8,
inner_receipt_len: u32,
image_id: *const u8,
image_id_len: u32,
public_input: *const u8,
public_input_len: u32,
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_risc_zero_receipt_ffi(
inner_receipt_bytes,
inner_receipt_len,
image_id,
image_id_len,
public_input,
public_input_len,
)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -69,7 +94,7 @@ mod tests {
public_input,
PUBLIC_INPUT.len() as u32,
);
assert!(result)
assert_eq!(result, 1)
}

#[test]
Expand All @@ -86,7 +111,7 @@ mod tests {
public_input,
PUBLIC_INPUT.len() as u32,
);
assert!(!result)
assert_eq!(result, 0)
}

#[test]
Expand All @@ -103,6 +128,6 @@ mod tests {
public_input,
0,
);
assert!(!result)
assert_eq!(result, 0)
}
}
38 changes: 30 additions & 8 deletions operator/risc_zero/risc_zero.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,44 @@ package risc_zero
#include "lib/risc_zero.h"
*/
import "C"
import (
"unsafe"
)
import "unsafe"
import "fmt"

func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil

func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) bool {
if len(innerReceiptBuffer) == 0 || len(imageIdBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying risc0 proof: %s", rec)
}
}()

receiptPtr := (*C.uchar)(unsafe.Pointer(&innerReceiptBuffer[0]))
imageIdPtr := (*C.uchar)(unsafe.Pointer(&imageIdBuffer[0]))

r := (C.int32_t)(0)

if len(publicInputBuffer) == 0 { // allow empty public input
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
} else {
publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
}

publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying risc0 proof")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
4 changes: 2 additions & 2 deletions operator/risc_zero/risc_zero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestFibonacciRiscZeroProofVerifies(t *testing.T) {
if err != nil {
t.Errorf("could not open public input file: %s", err)
}

if !risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes) {
verified, err := risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes)
if err != nil || !verified {
t.Errorf("proof did not verify")
}
}
2 changes: 1 addition & 1 deletion operator/sp1/lib/sp1.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len,
int32_t verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len,
unsigned char *elf_buffer, uint32_t elf_len);
24 changes: 20 additions & 4 deletions operator/sp1/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ lazy_static! {
static ref PROVER_CLIENT: ProverClient = ProverClient::new();
}

#[no_mangle]
pub extern "C" fn verify_sp1_proof_ffi(
fn inner_verify_sp1_proof_ffi(
proof_bytes: *const u8,
proof_len: u32,
elf_bytes: *const u8,
Expand Down Expand Up @@ -35,6 +34,23 @@ pub extern "C" fn verify_sp1_proof_ffi(
false
}

#[no_mangle]
pub extern "C" fn verify_sp1_proof_ffi(
proof_bytes: *const u8,
proof_len: u32,
elf_bytes: *const u8,
elf_len: u32,
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_sp1_proof_ffi(proof_bytes, proof_len, elf_bytes, elf_len)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -49,7 +65,7 @@ mod tests {

let result =
verify_sp1_proof_ffi(proof_bytes, PROOF.len() as u32, elf_bytes, ELF.len() as u32);
assert!(result)
assert_eq!(result, 1)
}

#[test]
Expand All @@ -63,6 +79,6 @@ mod tests {
elf_bytes,
ELF.len() as u32,
);
assert!(!result)
assert_eq!(result, 0)
}
}
27 changes: 24 additions & 3 deletions operator/sp1/sp1.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,35 @@ package sp1
*/
import "C"
import "unsafe"
import "fmt"

func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) bool {
func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil
if len(proofBuffer) == 0 || len(elfBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying sp1 proof: %s", rec)
}
}()

proofPtr := (*C.uchar)(unsafe.Pointer(&proofBuffer[0]))
elfPtr := (*C.uchar)(unsafe.Pointer(&elfBuffer[0]))

return (bool)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer))))
r := (C.int32_t)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer))))

if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying sp1 proof")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
Loading
Loading