From 4a8375d1873560241ae8eea96230a42635ed1764 Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Thu, 22 Feb 2024 16:28:25 +0100 Subject: [PATCH 1/2] test: fix tests sensitive to message ordering --- .../examples/server_api_precomputed.py | 8 +- ferveo-python/examples/server_api_simple.py | 6 + ferveo-python/test/test_ferveo.py | 24 +- ferveo-tdec/src/context.rs | 2 +- ferveo-tdec/src/decryption.rs | 14 +- ferveo-wasm/tests/node.rs | 27 +- ferveo/src/api.rs | 133 +++++---- ferveo/src/bindings_python.rs | 9 +- ferveo/src/dkg.rs | 36 +-- ferveo/src/lib.rs | 236 ++++++++------- ferveo/src/pvss.rs | 38 ++- ferveo/src/refresh.rs | 282 +++++++++++------- ferveo/src/test_common.rs | 3 +- 13 files changed, 478 insertions(+), 340 deletions(-) diff --git a/ferveo-python/examples/server_api_precomputed.py b/ferveo-python/examples/server_api_precomputed.py index a37ad573..77e21e61 100644 --- a/ferveo-python/examples/server_api_precomputed.py +++ b/ferveo-python/examples/server_api_precomputed.py @@ -39,6 +39,9 @@ def gen_eth_addr(i: int) -> str: ) messages.append(ValidatorMessage(sender, dkg.generate_transcript())) +# We only need `shares_num` messages to aggregate the transcript +messages = messages[:shares_num] + # Every validator can aggregate the transcripts dkg = Dkg( tau=tau, @@ -84,9 +87,12 @@ def gen_eth_addr(i: int) -> str: ) decryption_shares.append(decryption_share) +# We need `shares_num` decryption shares in precomputed variant +# TODO: This fails if shares_num != validators_num +decryption_shares = decryption_shares[:validators_num] + # Now, the decryption share can be used to decrypt the ciphertext # This part is in the client API - shared_secret = combine_decryption_shares_precomputed(decryption_shares) # The client should have access to the public parameters of the DKG diff --git a/ferveo-python/examples/server_api_simple.py b/ferveo-python/examples/server_api_simple.py index beda8133..972f3c45 100644 --- a/ferveo-python/examples/server_api_simple.py +++ b/ferveo-python/examples/server_api_simple.py @@ -40,6 +40,9 @@ def gen_eth_addr(i: int) -> str: ) messages.append(ValidatorMessage(sender, dkg.generate_transcript())) +# We only need `shares_num` messages to aggregate the transcript +messages = messages[:shares_num] + # Now that every validator holds a dkg instance and a transcript for every other validator, # every validator can aggregate the transcripts me = validators[0] @@ -90,6 +93,9 @@ def gen_eth_addr(i: int) -> str: ) decryption_shares.append(decryption_share) +# We only need `threshold` decryption shares in simple variant +decryption_shares = decryption_shares[:security_threshold] + # Now, the decryption share can be used to decrypt the ciphertext # This part is in the client API diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 82cbc4f1..45afe800 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -39,7 +39,7 @@ def combine_shares_for_variant(v: FerveoVariant, decryption_shares): def scenario_for_variant( - variant: FerveoVariant, shares_num, validators_num, threshold, shares_to_use + variant: FerveoVariant, shares_num, validators_num, threshold, dec_shares_to_use ): if variant not in [FerveoVariant.Simple, FerveoVariant.Precomputed]: raise ValueError("Unknown variant: " + variant) @@ -47,10 +47,11 @@ def scenario_for_variant( if validators_num < shares_num: raise ValueError("validators_num must be >= shares_num") - if variant == FerveoVariant.Precomputed and shares_to_use != validators_num: - raise ValueError( - "In precomputed variant, shares_to_use must be equal to validators_num" - ) + # TODO: Validate that + # if variant == FerveoVariant.Precomputed and dec_shares_to_use != validators_num: + # raise ValueError( + # "In precomputed variant, dec_shares_to_use must be equal to validators_num" + # ) tau = 1 validator_keypairs = [Keypair.random() for _ in range(0, validators_num)] @@ -72,6 +73,9 @@ def scenario_for_variant( ) messages.append(ValidatorMessage(sender, dkg.generate_transcript())) + # We only need `shares_num` messages to aggregate the transcript + messages = messages[:shares_num] + # Both client and server should be able to verify the aggregated transcript dkg = Dkg( tau=tau, @@ -113,7 +117,7 @@ def scenario_for_variant( decryption_shares.append(decryption_share) # We are limiting the number of decryption shares to use for testing purposes - # decryption_shares = decryption_shares[:shares_to_use] + decryption_shares = decryption_shares[:dec_shares_to_use] # Client combines the decryption shares and decrypts the ciphertext shared_secret = combine_shares_for_variant(variant, decryption_shares) @@ -141,7 +145,7 @@ def test_simple_tdec_has_enough_messages(): shares_num=shares_num, validators_num=validators_num, threshold=threshold, - shares_to_use=threshold, + dec_shares_to_use=threshold, ) @@ -154,7 +158,7 @@ def test_simple_tdec_doesnt_have_enough_messages(): shares_num=shares_num, validators_num=validators_num, threshold=threshold, - shares_to_use=validators_num - 1, + dec_shares_to_use=validators_num - 1, ) @@ -167,7 +171,7 @@ def test_precomputed_tdec_has_enough_messages(): shares_num=shares_num, validators_num=validators_num, threshold=threshold, - shares_to_use=validators_num, + dec_shares_to_use=validators_num, ) @@ -180,7 +184,7 @@ def test_precomputed_tdec_doesnt_have_enough_messages(): shares_num=shares_num, validators_num=validators_num, threshold=threshold, - shares_to_use=threshold - 1, + dec_shares_to_use=threshold - 1, ) diff --git a/ferveo-tdec/src/context.rs b/ferveo-tdec/src/context.rs index 6e565188..ed7faee0 100644 --- a/ferveo-tdec/src/context.rs +++ b/ferveo-tdec/src/context.rs @@ -100,7 +100,7 @@ impl PrivateDecryptionContextSimple { .collect::>(); let lagrange_coeffs = prepare_combine_simple::(&domain); - DecryptionSharePrecomputed::new( + DecryptionSharePrecomputed::create( self.index, &self.setup_params.b, &self.private_key_share, diff --git a/ferveo-tdec/src/decryption.rs b/ferveo-tdec/src/decryption.rs index 316d82f1..dec3ed78 100644 --- a/ferveo-tdec/src/decryption.rs +++ b/ferveo-tdec/src/decryption.rs @@ -72,6 +72,9 @@ impl ValidatorShareChecksum { } } +/// A decryption share for a simple variant of the threshold decryption scheme. +/// In this variant, the decryption share require additional computation on the +/// client side int order to be combined. #[serde_as] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct DecryptionShareSimple { @@ -141,6 +144,11 @@ impl DecryptionShareSimple { } } +/// A decryption share for a precomputed variant of the threshold decryption scheme. +/// In this variant, the decryption share is precomputed and can be combined +/// without additional computation on the client side. +/// The downside is that the threshold of decryption shares required to decrypt +/// is equal to the number of private key shares in the scheme. #[serde_as] #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct DecryptionSharePrecomputed { @@ -155,7 +163,9 @@ pub struct DecryptionSharePrecomputed { } impl DecryptionSharePrecomputed { - pub fn new( + /// Create a decryption share from the given parameters. + /// This function checks that the ciphertext is valid. + pub fn create( validator_index: usize, validator_decryption_key: &E::ScalarField, private_key_share: &PrivateKeyShare, @@ -174,6 +184,8 @@ impl DecryptionSharePrecomputed { ) } + /// Create a decryption share from the given parameters. + /// This function does not check that the ciphertext is valid. pub fn create_unchecked( validator_index: usize, validator_decryption_key: &E::ScalarField, diff --git a/ferveo-wasm/tests/node.rs b/ferveo-wasm/tests/node.rs index 5d4ffbb4..d3a5ea43 100644 --- a/ferveo-wasm/tests/node.rs +++ b/ferveo-wasm/tests/node.rs @@ -1,4 +1,4 @@ -//! Test suite for the Nodejs. +//! Test suite for the Node.js. extern crate wasm_bindgen_test; @@ -45,7 +45,6 @@ fn setup_dkg( ) .unwrap(); let transcript = validator_dkg.generate_transcript().unwrap(); - ValidatorMessage::new(sender, &transcript).unwrap() }); @@ -61,6 +60,8 @@ fn setup_dkg( ) .unwrap(); + // We only need `shares_num` messages to aggregate the transcripts + let messages = messages.take(shares_num as usize).collect::>(); let messages_js = into_js_array(messages); // Server can aggregate the transcripts and verify them @@ -125,7 +126,6 @@ fn tdec_simple() { let is_valid = aggregate.verify(validators_num, &messages_js).unwrap(); assert!(is_valid); - aggregate .create_decryption_share_simple( &dkg, @@ -135,16 +135,16 @@ fn tdec_simple() { ) .unwrap() }) + // We only need `security_threshold` decryption shares in simple variant + .take(security_threshold as usize) .collect::>(); - let decryption_shares_js = into_js_array(decryption_shares); - // Now, the decryption share can be used to decrypt the ciphertext - // This part is in the client API + let decryption_shares_js = into_js_array(decryption_shares); + // Now, decryption shares can be used to decrypt the ciphertext + // This part happens in the client API let shared_secret = combine_decryption_shares_simple(&decryption_shares_js).unwrap(); - - // The client should have access to the public parameters of the DKG let plaintext = decrypt_with_shared_secret(&ciphertext, &aad, &shared_secret) .unwrap(); @@ -183,7 +183,6 @@ fn tdec_precomputed() { let is_valid = aggregate.verify(validators_num, &messages_js).unwrap(); assert!(is_valid); - aggregate .create_decryption_share_precomputed( &dkg, @@ -193,17 +192,17 @@ fn tdec_precomputed() { ) .unwrap() }) + // We need `shares_num` decryption shares in precomputed variant + // TODO: This fails if shares_num != validators_num + .take(validators_num as usize) .collect::>(); let decryption_shares_js = into_js_array(decryption_shares); - // Now, the decryption share can be used to decrypt the ciphertext - // This part is in the client API - + // Now, decryption shares can be used to decrypt the ciphertext + // This part happens in the client API let shared_secret = combine_decryption_shares_precomputed(&decryption_shares_js) .unwrap(); - - // The client should have access to the public parameters of the DKG let plaintext = decrypt_with_shared_secret(&ciphertext, &aad, &shared_secret) .unwrap(); diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index 923de60a..7ab62c56 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -1,4 +1,4 @@ -use std::{fmt, io}; +use std::{collections::HashMap, fmt, io}; use ark_ec::CurveGroup; use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; @@ -25,7 +25,7 @@ use crate::bindings_python; use crate::bindings_wasm; pub use crate::EthereumAddress; use crate::{ - do_verify_aggregation, DomainPoint, Error, PubliclyVerifiableParams, + do_verify_aggregation, Error, PubliclyVerifiableParams, PubliclyVerifiableSS, Result, }; @@ -34,6 +34,7 @@ pub type Keypair = ferveo_common::Keypair; pub type Validator = crate::Validator; pub type Transcript = PubliclyVerifiableSS; pub type ValidatorMessage = (Validator, Transcript); +pub type DomainPoint = crate::DomainPoint; // Normally, we would use a custom trait for this, but we can't because // the `arkworks` will not let us create a blanket implementation for G1Affine @@ -239,7 +240,7 @@ impl Dkg { &self.0.me } - pub fn domain_points(&self) -> Vec> { + pub fn domain_points(&self) -> Vec { self.0.domain_points() } } @@ -369,11 +370,10 @@ impl AggregatedTranscript { pub struct DecryptionShareSimple { share: ferveo_tdec::api::DecryptionShareSimple, #[serde_as(as = "serialization::SerdeAs")] - domain_point: DomainPoint, + domain_point: DomainPoint, } pub fn combine_shares_simple(shares: &[DecryptionShareSimple]) -> SharedSecret { - // Pick domain points that are corresponding to the shares we have. let domain_points: Vec<_> = shares.iter().map(|s| s.domain_point).collect(); let lagrange_coefficients = prepare_combine_simple::(&domain_points); @@ -387,6 +387,7 @@ pub fn combine_shares_simple(shares: &[DecryptionShareSimple]) -> SharedSecret { pub struct SharedSecret(pub ferveo_tdec::api::SharedSecret); #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +// TODO: Use refresh::ShareRecoveryUpdate instead of ferveo_tdec::PrivateKeyShare pub struct ShareRecoveryUpdate(pub ferveo_tdec::PrivateKeyShare); impl ShareRecoveryUpdate { @@ -395,21 +396,23 @@ impl ShareRecoveryUpdate { pub fn create_share_updates( // TODO: Decouple from Dkg? We don't need any specific Dkg instance here, just some params etc dkg: &Dkg, - x_r: &DomainPoint, - ) -> Result> { + x_r: &DomainPoint, + ) -> Result> { let rng = &mut thread_rng(); - let updates = + let update_map = crate::refresh::ShareRecoveryUpdate::create_share_updates( - &dkg.0.domain_points(), + &dkg.0.domain_point_map(), &dkg.0.pvss_params.h.into_affine(), x_r, dkg.0.dkg_params.security_threshold(), rng, ) - .iter() - .map(|update| ShareRecoveryUpdate(update.0.clone())) + .into_iter() + .map(|(share_index, share_update)| { + (share_index, ShareRecoveryUpdate(share_update.0.clone())) + }) .collect(); - Ok(updates) + Ok(update_map) } pub fn to_bytes(&self) -> Result> { @@ -426,17 +429,21 @@ impl ShareRecoveryUpdate { pub struct ShareRefreshUpdate(pub crate::ShareRefreshUpdate); impl ShareRefreshUpdate { - pub fn create_share_updates(dkg: &Dkg) -> Result> { + pub fn create_share_updates( + dkg: &Dkg, + ) -> Result> { let rng = &mut thread_rng(); let updates = crate::refresh::ShareRefreshUpdate::create_share_updates( - &dkg.0.domain_points(), + &dkg.0.domain_point_map(), &dkg.0.pvss_params.h.into_affine(), dkg.0.dkg_params.security_threshold(), rng, ) .into_iter() - .map(ShareRefreshUpdate) - .collect(); + .map(|(share_index, share_update)| { + (share_index, ShareRefreshUpdate(share_update)) + }) + .collect::>(); Ok(updates) } @@ -499,21 +506,20 @@ impl PrivateKeyShare { /// Recover a private key share from updated private key shares pub fn recover_share_from_updated_private_shares( - x_r: &DomainPoint, - domain_points: &[DomainPoint], - updated_shares: &[UpdatedPrivateKeyShare], + x_r: &DomainPoint, + domain_points: &HashMap, + updated_shares: &HashMap, ) -> Result { - let updated_shares: Vec<_> = updated_shares + let updated_shares = updated_shares .iter() - .cloned() - .map(|updated| updated.0) - .collect(); + .map(|(k, v)| (*k, v.0.clone())) + .collect::>(); let share = crate::PrivateKeyShare::recover_share_from_updated_private_shares( x_r, domain_points, - &updated_shares[..], - ); + &updated_shares, + )?; Ok(PrivateKeyShare(share)) } @@ -544,7 +550,7 @@ impl PrivateKeyShare { aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points: &[DomainPoint], ) -> Result { let share = self.0.create_decryption_share_simple_precomputed( &ciphertext_header.0, @@ -604,7 +610,7 @@ mod test_ferveo_api { // Each validator holds their own DKG instance and generates a transcript every // validator, including themselves - let messages: Vec<_> = validators + let mut messages: Vec<_> = validators .iter() .map(|sender| { let dkg = Dkg::new( @@ -618,7 +624,7 @@ mod test_ferveo_api { (sender.clone(), dkg.0.generate_transcript(rng).unwrap()) }) .collect(); - + messages.shuffle(rng); (messages, validators, validator_keypairs) } @@ -639,7 +645,6 @@ mod test_ferveo_api { // In precomputed variant, the security threshold is equal to the number of shares let security_threshold = shares_num; - let (messages, validators, validator_keypairs) = make_test_inputs( rng, TAU, @@ -647,14 +652,16 @@ mod test_ferveo_api { shares_num, validators_num, ); + // We only need `shares_num` transcripts to aggregate + let messages = &messages[..shares_num as usize]; // Every validator can aggregate the transcripts let me = validators[0].clone(); let dkg = Dkg::new(TAU, shares_num, security_threshold, &validators, &me) .unwrap(); - let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, &messages).unwrap()); + let pvss_aggregated = dkg.aggregate_transcripts(messages).unwrap(); + assert!(pvss_aggregated.verify(validators_num, messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key let dkg_public_key = pvss_aggregated.public_key(); @@ -678,9 +685,9 @@ mod test_ferveo_api { ) .unwrap(); let aggregate = - dkg.aggregate_transcripts(&messages).unwrap(); + dkg.aggregate_transcripts(messages).unwrap(); assert!(pvss_aggregated - .verify(validators_num, &messages) + .verify(validators_num, messages) .unwrap()); // And then each validator creates their own decryption share @@ -710,7 +717,6 @@ mod test_ferveo_api { // Since we're using a precomputed variant, we need all the shares to be able to decrypt // So if we remove one share, we should not be able to decrypt - let decryption_shares = decryption_shares[..shares_num as usize - 1].to_vec(); let shared_secret = share_combine_precomputed(&decryption_shares); @@ -727,8 +733,8 @@ mod test_ferveo_api { #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_server_api_tdec_simple(shares_num: u32, validators_num: u32) { let rng = &mut StdRng::seed_from_u64(0); - let security_threshold = shares_num / 2 + 1; + let security_threshold = shares_num / 2 + 1; let (messages, validators, validator_keypairs) = make_test_inputs( rng, TAU, @@ -736,6 +742,8 @@ mod test_ferveo_api { shares_num, validators_num, ); + // We only need `shares_num` transcripts to aggregate + let messages = &messages[..shares_num as usize]; // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts @@ -747,8 +755,8 @@ mod test_ferveo_api { &validators[0], ) .unwrap(); - let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, &messages).unwrap()); + let pvss_aggregated = dkg.aggregate_transcripts(messages).unwrap(); + assert!(pvss_aggregated.verify(validators_num, messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key let public_key = pvss_aggregated.public_key(); @@ -771,9 +779,9 @@ mod test_ferveo_api { ) .unwrap(); let aggregate = - dkg.aggregate_transcripts(&messages).unwrap(); + dkg.aggregate_transcripts(messages).unwrap(); assert!(aggregate - .verify(validators_num, &messages) + .verify(validators_num, messages) .unwrap()); aggregate .create_decryption_share_simple( @@ -829,6 +837,8 @@ mod test_ferveo_api { shares_num, validators_num, ); + // We only need `shares_num` transcripts to aggregate + let messages = &messages[..shares_num as usize]; // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts @@ -836,8 +846,8 @@ mod test_ferveo_api { let dkg = Dkg::new(TAU, shares_num, security_threshold, &validators, &me) .unwrap(); - let good_aggregate = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(good_aggregate.verify(validators_num, &messages).is_ok()); + let good_aggregate = dkg.aggregate_transcripts(messages).unwrap(); + assert!(good_aggregate.verify(validators_num, messages).is_ok()); // Test negative cases @@ -846,7 +856,7 @@ mod test_ferveo_api { // Should fail if the number of validators is less than the number of messages assert!(matches!( - good_aggregate.verify(messages.len() as u32 - 1, &messages), + good_aggregate.verify(messages.len() as u32 - 1, messages), Err(Error::InvalidAggregateVerificationParameters(_, _)) )); @@ -868,7 +878,7 @@ mod test_ferveo_api { let insufficient_aggregate = dkg.aggregate_transcripts(not_enough_messages).unwrap(); assert!(matches!( - insufficient_aggregate.verify(validators_num, &messages), + insufficient_aggregate.verify(validators_num, messages), Err(Error::InvalidTranscriptAggregate) )); @@ -919,7 +929,7 @@ mod test_ferveo_api { assert_eq!(mixed_messages.len(), security_threshold as usize); let bad_aggregate = dkg.aggregate_transcripts(&mixed_messages).unwrap(); assert!(matches!( - bad_aggregate.verify(validators_num, &messages), + bad_aggregate.verify(validators_num, messages), Err(Error::InvalidTranscriptAggregate) )); } @@ -939,8 +949,8 @@ mod test_ferveo_api { validators_num, ); - // We only need `security_threshold` transcripts to aggregate - let messages = &messages[..security_threshold as usize]; + // We only need `shares_num` transcripts to aggregate + let messages = &messages[..shares_num as usize]; // Create an aggregated transcript on the client side let good_aggregate = AggregatedTranscript::new(messages).unwrap(); @@ -1100,8 +1110,10 @@ mod test_ferveo_api { .unwrap()); // We need to save this domain point to be user in the recovery testing scenario - let mut domain_points = dkgs[0].domain_points(); - let removed_domain_point = domain_points.pop().unwrap(); + let mut domain_points = dkgs[0].0.domain_point_map(); + let removed_domain_point = domain_points + .remove(&validators.last().unwrap().share_index) + .unwrap(); // Remove one participant from the contexts and all nested structure // to simulate off-boarding a validator @@ -1114,7 +1126,7 @@ mod test_ferveo_api { // and check that the shared secret is still the same. let x_r = if recover_at_random_point { // Onboarding a validator with a completely new private key share - DomainPoint::::rand(rng) + DomainPoint::rand(rng) } else { // Onboarding a validator with a private key share recovered from the removed validator removed_domain_point @@ -1136,16 +1148,14 @@ mod test_ferveo_api { // Participants share updates and update their shares // Now, every participant separately: - let updated_shares: Vec<_> = dkgs + let updated_shares: HashMap = dkgs .iter() .map(|validator_dkg| { // Current participant receives updates from other participants let updates_for_participant: Vec<_> = share_updates .values() .map(|updates| { - updates - .get(validator_dkg.me().share_index as usize) - .unwrap() + updates.get(&validator_dkg.me().share_index).unwrap() }) .cloned() .collect(); @@ -1156,7 +1166,7 @@ mod test_ferveo_api { .unwrap(); // And creates updated private key shares - aggregated_transcript + let updated_key_share = aggregated_transcript .get_private_key_share( validator_keypair, validator_dkg.me().share_index, @@ -1165,7 +1175,8 @@ mod test_ferveo_api { .create_updated_private_key_share_for_recovery( &updates_for_participant, ) - .unwrap() + .unwrap(); + (validator_dkg.me().share_index, updated_key_share) }) .collect(); @@ -1228,11 +1239,15 @@ mod test_ferveo_api { ) .unwrap(); decryption_shares.push(new_decryption_share); - domain_points.push(x_r); + domain_points.insert(new_validator_share_index, x_r); assert_eq!(domain_points.len(), validators_num as usize); assert_eq!(decryption_shares.len(), validators_num as usize); - let domain_points = &domain_points[..security_threshold as usize]; + let domain_points = domain_points + .values() + .take(security_threshold as usize) + .cloned() + .collect::>(); let decryption_shares = &decryption_shares[..security_threshold as usize]; assert_eq!(domain_points.len(), security_threshold as usize); @@ -1290,9 +1305,7 @@ mod test_ferveo_api { let updates_for_participant: Vec<_> = share_updates .values() .map(|updates| { - updates - .get(validator_dkg.me().share_index as usize) - .unwrap() + updates.get(&validator_dkg.me().share_index).unwrap() }) .cloned() .collect(); diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index ecdb9e56..0689a495 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -816,7 +816,7 @@ mod test_ferveo_python { .collect(); // Each validator holds their own DKG instance and generates a transcript every - // every validator, including themselves + // validator, including themselves let messages: Vec<_> = validators .iter() .cloned() @@ -864,7 +864,7 @@ mod test_ferveo_python { ) .unwrap(); - // Lets say that we've only received `security_threshold` transcripts + // Let's say that we've only received `security_threshold` transcripts let messages = messages[..security_threshold as usize].to_vec(); let pvss_aggregated = dkg.aggregate_transcripts(messages.clone()).unwrap(); @@ -912,7 +912,6 @@ mod test_ferveo_python { // This part is part of the client API let shared_secret = combine_decryption_shares_precomputed(decryption_shares); - let plaintext = decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret) .unwrap(); @@ -942,7 +941,7 @@ mod test_ferveo_python { ) .unwrap(); - // Lets say that we've only receives `security_threshold` transcripts + // Let's say that we've only receives `security_threshold` transcripts let messages = messages[..security_threshold as usize].to_vec(); let pvss_aggregated = dkg.aggregate_transcripts(messages.clone()).unwrap(); @@ -989,9 +988,7 @@ mod test_ferveo_python { // Now, the decryption share can be used to decrypt the ciphertext // This part is part of the client API - let shared_secret = combine_decryption_shares_simple(decryption_shares); - let plaintext = decrypt_with_shared_secret(&ciphertext, AAD, &shared_secret) .unwrap(); diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index 087b3069..d2e825a7 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; use ark_ec::pairing::Pairing; use ark_poly::EvaluationDomain; @@ -155,8 +155,8 @@ impl PubliclyVerifiableDkg { /// Return a domain point for the share_index pub fn get_domain_point(&self, share_index: u32) -> Result> { - self.domain_points() - .get(share_index as usize) + self.domain_point_map() + .get(&share_index) .ok_or_else(|| Error::InvalidShareIndex(share_index)) .copied() } @@ -167,6 +167,15 @@ impl PubliclyVerifiableDkg { self.domain.elements().take(self.validators.len()).collect() } + /// Return a map of domain points for the DKG + pub fn domain_point_map(&self) -> HashMap> { + self.domain_points() + .iter() + .enumerate() + .map(|(i, point)| (i as u32, *point)) + .collect::>() + } + /// Verify PVSS transcripts against the set of validators in the DKG fn verify_transcripts( &self, @@ -189,18 +198,14 @@ impl PubliclyVerifiableDkg { transcript_set.insert(transcript.clone()); } - if validator_set.len() > self.validators.len() { + if validator_set.len() > self.validators.len() + || transcript_set.len() > self.validators.len() + { return Err(Error::TooManyTranscripts( self.validators.len() as u32, validator_set.len() as u32, )); } - if transcript_set.len() > self.validators.len() { - return Err(Error::TooManyTranscripts( - self.validators.len() as u32, - transcript_set.len() as u32, - )); - } Ok(()) } @@ -233,7 +238,6 @@ mod test_dkg_init { &unknown_validator, ) .unwrap_err(); - assert_eq!(err.to_string(), "Expected validator to be a part of the DKG validator set: 0x0000000000000000000000000000000000000005") } } @@ -278,7 +282,6 @@ mod test_dealing { let rng = &mut ark_std::test_rng(); let (dkg, _) = setup_dkg(0); let messages = make_messages(rng, &dkg); - assert!(dkg.verify_transcripts(&messages).is_ok()); } @@ -334,14 +337,11 @@ mod test_dealing { /// Test aggregating transcripts into final key #[cfg(test)] mod test_aggregation { - use test_case::test_case; - use crate::test_common::*; /// Test that if the security threshold is met, we can create a final key - #[test_case(4, 4; "number of validators equal to the number of shares")] - #[test_case(4, 6; "number of validators greater than the number of shares")] - fn test_aggregate(_shares_num: u32, _validators_num: u32) { + #[test] + fn test_aggregate() { let rng = &mut ark_std::test_rng(); let (dkg, _) = setup_dkg(0); let all_messages = make_messages(rng, &dkg); @@ -356,7 +356,7 @@ mod test_aggregation { let enough_messages = all_messages .iter() - .take((dkg.dkg_params.security_threshold) as usize) + .take(dkg.dkg_params.security_threshold as usize) .cloned() .collect::>(); let good_aggregate_1 = diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 832f0564..67e501fe 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -114,15 +114,14 @@ mod test_dkg_full { use ark_bls12_381::{Bls12_381 as E, Fr, G1Affine}; use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{UniformRand, Zero}; - use ark_poly::EvaluationDomain; use ark_std::test_rng; use ferveo_common::Keypair; use ferveo_tdec::{ self, DecryptionSharePrecomputed, DecryptionShareSimple, SecretBox, SharedSecret, }; - use itertools::izip; - use rand::seq::SliceRandom; + use itertools::{izip, Itertools}; + use rand::{seq::SliceRandom, Rng}; use test_case::test_case; use super::*; @@ -144,7 +143,7 @@ mod test_dkg_full { assert!(pvss_aggregated .aggregate .verify_aggregation(dkg, transcripts) - .is_ok()); + .unwrap()); let decryption_shares: Vec> = validator_keypairs @@ -163,13 +162,11 @@ mod test_dkg_full { ) .unwrap() }) + // We take only the first `security_threshold` decryption shares + .take(dkg.dkg_params.security_threshold() as usize) .collect(); - let domain_points = &dkg - .domain - .elements() - .take(decryption_shares.len()) - .collect::>(); + let domain_points = &dkg.domain_points()[..decryption_shares.len()]; assert_eq!(domain_points.len(), decryption_shares.len()); let lagrange_coeffs = @@ -178,7 +175,6 @@ mod test_dkg_full { &decryption_shares, &lagrange_coeffs, ); - (pvss_aggregated, decryption_shares, shared_secret) } @@ -195,8 +191,11 @@ mod test_dkg_full { shares_num, validators_num, ); - let transcripts = - messages.iter().map(|m| m.1.clone()).collect::>(); + let transcripts = messages + .iter() + .take(shares_num as usize) + .map(|m| m.1.clone()) + .collect::>(); let public_key = AggregatedTranscript::from_transcripts(&transcripts) .unwrap() .public_key; @@ -228,26 +227,30 @@ mod test_dkg_full { #[test_case(4, 4; "number of shares (validators) is a power of 2")] #[test_case(7, 7; "number of shares (validators) is not a power of 2")] - #[test_case(4, 6; "number of validators greater than the number of shares")] + // TODO: This test fails: + // #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_dkg_simple_tdec_precomputed(shares_num: u32, validators_num: u32) { let rng = &mut test_rng(); // In precomputed variant, threshold must be equal to shares_num let security_threshold = shares_num; - let (dkg, validator_keypairs, messangers) = + let (dkg, validator_keypairs, messages) = setup_dealt_dkg_with_n_validators( security_threshold, shares_num, validators_num, ); - let transcripts = - messangers.iter().map(|m| m.1.clone()).collect::>(); + let transcripts = messages + .iter() + .take(shares_num as usize) + .map(|m| m.1.clone()) + .collect::>(); let pvss_aggregated = AggregatedTranscript::from_transcripts(&transcripts).unwrap(); - pvss_aggregated + assert!(pvss_aggregated .aggregate .verify_aggregation(&dkg, &transcripts) - .unwrap(); + .unwrap()); let public_key = pvss_aggregated.public_key; let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), @@ -257,12 +260,6 @@ mod test_dkg_full { ) .unwrap(); - let domain_points = dkg - .domain - .elements() - .take(validator_keypairs.len()) - .collect::>(); - let mut decryption_shares: Vec> = validator_keypairs .iter() @@ -277,18 +274,20 @@ mod test_dkg_full { AAD, validator_keypair, validator.share_index, - &domain_points, + &dkg.domain_points(), ) .unwrap() }) + // We take only the first `security_threshold` decryption shares + .take(dkg.dkg_params.security_threshold() as usize) .collect(); + + // Order of decryption shares is not important in the precomputed variant decryption_shares.shuffle(rng); - assert_eq!(domain_points.len(), decryption_shares.len()); + // Decrypt with precomputed variant let shared_secret = ferveo_tdec::share_combine_precomputed::(&decryption_shares); - - // Combination works, let's decrypt let plaintext = ferveo_tdec::decrypt_with_shared_secret( &ciphertext, AAD, @@ -315,8 +314,11 @@ mod test_dkg_full { shares_num, validators_num, ); - let transcripts = - messages.iter().map(|m| m.1.clone()).collect::>(); + let transcripts = messages + .iter() + .take(shares_num as usize) + .map(|m| m.1.clone()) + .collect::>(); let public_key = AggregatedTranscript::from_transcripts(&transcripts) .unwrap() .public_key; @@ -393,8 +395,11 @@ mod test_dkg_full { shares_num, validators_num, ); - let transcripts = - messages.iter().map(|m| m.1.clone()).collect::>(); + let transcripts = messages + .iter() + .take(shares_num as usize) + .map(|m| m.1.clone()) + .collect::>(); let public_key = AggregatedTranscript::from_transcripts(&transcripts) .unwrap() .public_key; @@ -415,22 +420,27 @@ mod test_dkg_full { &transcripts, ); + // TODO: Rewrite this test so that the offboarding of validator + // is done by recreating a DKG instance with a new set of + // validators from the Coordinator, rather than modifying the + // existing DKG instance. + // Remove one participant from the contexts and all nested structure - let removed_validator_addr = - dkg.validators.keys().last().unwrap().clone(); + let removed_validator_index = rng.gen_range(0..validators_num); + let removed_validator_addr = dkg + .validators + .iter() + .find(|(_, v)| v.share_index == removed_validator_index) + .unwrap() + .1 + .address + .clone(); let mut remaining_validators = dkg.validators.clone(); - remaining_validators - .remove(&removed_validator_addr) - .unwrap(); - - let mut remaining_validator_keypairs = validator_keypairs.clone(); - remaining_validator_keypairs - .pop() - .expect("Should have a keypair"); + remaining_validators.remove(&removed_validator_addr); // Remember to remove one domain point too - let mut domain_points = dkg.domain_points(); - domain_points.pop().unwrap(); + let mut domain_points = dkg.domain_point_map(); + domain_points.remove(&removed_validator_index); // Now, we're going to recover a new share at a random point, // and check that the shared secret is still the same. @@ -438,7 +448,7 @@ mod test_dkg_full { // Our random point: let x_r = Fr::rand(rng); - // Each participant prepares an update for each other participant + // Each participant prepares an update for every other participant let share_updates = remaining_validators .keys() .map(|v_addr| { @@ -456,15 +466,13 @@ mod test_dkg_full { // Participants share updates and update their shares // Now, every participant separately: - let updated_shares: Vec<_> = remaining_validators + let updated_shares: HashMap = remaining_validators .values() .map(|validator| { // Current participant receives updates from other participants - let updates_for_participant: Vec<_> = share_updates + let updates_for_validator: Vec<_> = share_updates .values() - .map(|updates| { - updates.get(validator.share_index as usize).unwrap() - }) + .map(|updates| updates.get(&validator.share_index).unwrap()) .cloned() .collect(); @@ -474,15 +482,17 @@ mod test_dkg_full { .unwrap(); // Creates updated private key shares - AggregatedTranscript::from_transcripts(&transcripts) - .unwrap() - .aggregate - .create_updated_private_key_share( - validator_keypair, - validator.share_index, - updates_for_participant.as_slice(), - ) - .unwrap() + let updated_key_share = + AggregatedTranscript::from_transcripts(&transcripts) + .unwrap() + .aggregate + .create_updated_private_key_share( + validator_keypair, + validator.share_index, + updates_for_validator.as_slice(), + ) + .unwrap(); + (validator.share_index, updated_key_share) }) .collect(); @@ -492,14 +502,17 @@ mod test_dkg_full { &x_r, &domain_points, &updated_shares, - ); + ) + .unwrap(); // Get decryption shares from remaining participants - let mut decryption_shares: Vec> = - remaining_validator_keypairs - .iter() - .enumerate() - .map(|(share_index, validator_keypair)| { + let mut decryption_shares = remaining_validators + .values() + .map(|validator| { + let validator_keypair = validator_keypairs + .get(validator.share_index as usize) + .unwrap(); + let decryption_share = AggregatedTranscript::from_transcripts(&transcripts) .unwrap() .aggregate @@ -507,42 +520,57 @@ mod test_dkg_full { &ciphertext.header().unwrap(), AAD, validator_keypair, - share_index as u32, + validator.share_index, ) - .unwrap() - }) - .collect(); + .unwrap(); + (validator.share_index, decryption_share) + }) + // We take only the first `security_threshold - 1` decryption shares + .take((dkg.dkg_params.security_threshold() - 1) as usize) + .collect::>(); // Create a decryption share from a recovered private key share let new_validator_decryption_key = Fr::rand(rng); - decryption_shares.push( - DecryptionShareSimple::create( - &new_validator_decryption_key, - &recovered_key_share.0, - &ciphertext.header().unwrap(), - AAD, - &dkg.pvss_params.g_inv(), - ) - .unwrap(), - ); - - domain_points.push(x_r); - assert_eq!(domain_points.len(), validators_num as usize); - assert_eq!(decryption_shares.len(), validators_num as usize); - - // TODO: Maybe parametrize this test with [1..] and [..threshold] - let domain_points = &domain_points[..security_threshold as usize]; - let decryption_shares = - &decryption_shares[..security_threshold as usize]; - assert_eq!(domain_points.len(), security_threshold as usize); - assert_eq!(decryption_shares.len(), security_threshold as usize); + let new_decryption_share = DecryptionShareSimple::create( + &new_validator_decryption_key, + &recovered_key_share.0, + &ciphertext.header().unwrap(), + AAD, + &dkg.pvss_params.g_inv(), + ) + .unwrap(); + decryption_shares.insert(removed_validator_index, new_decryption_share); + domain_points.insert(removed_validator_index, x_r); + + // We need to make sure that the domain points and decryption shares are ordered + // by the share index, so that the lagrange basis is calculated correctly + + let mut domain_points_ = vec![]; + let mut decryption_shares_ = vec![]; + for share_index in decryption_shares.keys().sorted() { + domain_points_.push( + *domain_points + .get(share_index) + .ok_or(Error::InvalidShareIndex(*share_index)) + .unwrap(), + ); + decryption_shares_.push( + decryption_shares + .get(share_index) + .ok_or(Error::InvalidShareIndex(*share_index)) + .unwrap() + .clone(), + ); + } + assert_eq!(domain_points_.len(), security_threshold as usize); + assert_eq!(decryption_shares_.len(), security_threshold as usize); - let lagrange = ferveo_tdec::prepare_combine_simple::(domain_points); + let lagrange = + ferveo_tdec::prepare_combine_simple::(&domain_points_); let new_shared_secret = ferveo_tdec::share_combine_simple::( - decryption_shares, + &decryption_shares_, &lagrange, ); - assert_eq!( old_shared_secret, new_shared_secret, "Shared secret reconstruction failed" @@ -565,8 +593,11 @@ mod test_dkg_full { shares_num, validators_num, ); - let transcripts = - messages.iter().map(|m| m.1.clone()).collect::>(); + let transcripts = messages + .iter() + .take(shares_num as usize) + .map(|m| m.1.clone()) + .collect::>(); let public_key = AggregatedTranscript::from_transcripts(&transcripts) .unwrap() .public_key; @@ -593,7 +624,7 @@ mod test_dkg_full { .keys() .map(|v_addr| { let deltas_i = ShareRefreshUpdate::create_share_updates( - &dkg.domain_points(), + &dkg.domain_point_map(), &dkg.pvss_params.h.into_affine(), dkg.dkg_params.security_threshold(), rng, @@ -613,10 +644,7 @@ mod test_dkg_full { let updates_for_participant: Vec<_> = share_updates .values() .map(|updates| { - updates - .get(validator.share_index as usize) - .cloned() - .unwrap() + updates.get(&validator.share_index).cloned().unwrap() }) .collect(); @@ -659,8 +687,15 @@ mod test_dkg_full { ) .unwrap() }) + // We take only the first `security_threshold` decryption shares + .take(dkg.dkg_params.security_threshold() as usize) .collect(); + // Order of decryption shares is not important, but since we are using low-level + // API here to performa a refresh for testing purpose, we will not shuffle + // the shares this time + // decryption_shares.shuffle(rng); + let lagrange = ferveo_tdec::prepare_combine_simple::( &dkg.domain_points()[..security_threshold as usize], ); @@ -668,7 +703,6 @@ mod test_dkg_full { &decryption_shares[..security_threshold as usize], &lagrange, ); - assert_eq!(old_shared_secret, new_shared_secret); } } diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index dc07d5b7..700db9cb 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -240,22 +240,28 @@ pub fn do_verify_full( domain.fft_in_place(&mut commitment); // Each validator checks that their share is correct - Ok(validators - .iter() - .zip(pvss_encrypted_shares.iter()) - .enumerate() - .all(|(share_index, (validator, y_i))| { - // TODO: Check #3 is missing - // See #3 in 4.2.3 section of https://eprint.iacr.org/2022/898.pdf - - // Validator checks aggregated shares against commitment - let ek_i = validator.public_key.encryption_key.into_group(); - let a_i = &commitment[share_index]; - // We verify that e(G, Y_i) = e(A_i, ek_i) for validator i - // See #4 in 4.2.3 section of https://eprint.iacr.org/2022/898.pdf - // e(G,Y) = e(A, ek) - E::pairing(pvss_params.g, *y_i) == E::pairing(a_i, ek_i) - })) + for validator in validators { + // TODO: Check #3 is missing + // See #3 in 4.2.3 section of https://eprint.iacr.org/2022/898.pdf + + let y_i = pvss_encrypted_shares + .get(validator.share_index as usize) + .ok_or(Error::InvalidShareIndex(validator.share_index))?; + // Validator checks aggregated shares against commitment + let ek_i = validator.public_key.encryption_key.into_group(); + let a_i = commitment + .get(validator.share_index as usize) + .ok_or(Error::InvalidShareIndex(validator.share_index))?; + // We verify that e(G, Y_i) = e(A_i, ek_i) for validator i + // See #4 in 4.2.3 section of https://eprint.iacr.org/2022/898.pdf + // e(G,Y) = e(A, ek) + let is_valid = E::pairing(pvss_params.g, *y_i) == E::pairing(a_i, ek_i); + if !is_valid { + return Ok(false); + } + } + + Ok(true) } pub fn do_verify_aggregation( diff --git a/ferveo/src/refresh.rs b/ferveo/src/refresh.rs index bf633fa2..0b8a95ef 100644 --- a/ferveo/src/refresh.rs +++ b/ferveo/src/refresh.rs @@ -1,4 +1,4 @@ -use std::{ops::Mul, usize}; +use std::{collections::HashMap, ops::Mul, usize}; use ark_ec::{pairing::Pairing, CurveGroup}; use ark_ff::Zero; @@ -8,7 +8,7 @@ use ferveo_tdec::{ lagrange_basis_at, prepare_combine_simple, CiphertextHeader, DecryptionSharePrecomputed, DecryptionShareSimple, }; -use itertools::zip_eq; +use itertools::{zip_eq, Itertools}; use rand_core::RngCore; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use zeroize::ZeroizeOnDrop; @@ -54,15 +54,35 @@ impl PrivateKeyShare { /// `x_r` is the point at which the share is to be recovered pub fn recover_share_from_updated_private_shares( x_r: &DomainPoint, - domain_points: &[DomainPoint], - updated_private_shares: &[UpdatedPrivateKeyShare], - ) -> PrivateKeyShare { + domain_points: &HashMap>, + updated_shares: &HashMap>, + ) -> Result> { + // Pick the domain points and updated shares according to share index + let mut domain_points_ = vec![]; + let mut updated_shares_ = vec![]; + for share_index in updated_shares.keys().sorted() { + domain_points_.push( + *domain_points + .get(share_index) + .ok_or(Error::InvalidShareIndex(*share_index))?, + ); + updated_shares_.push( + updated_shares + .get(share_index) + .ok_or(Error::InvalidShareIndex(*share_index))? + .0 + .clone(), + ); + } + // Interpolate new shares to recover y_r - let lagrange = lagrange_basis_at::(domain_points, x_r); - let prods = zip_eq(updated_private_shares, lagrange) - .map(|(y_j, l)| y_j.0 .0.mul(l)); + let lagrange = lagrange_basis_at::(&domain_points_, x_r); + let prods = + zip_eq(updated_shares_, lagrange).map(|(y_j, l)| y_j.0.mul(l)); let y_r = prods.fold(E::G2::zero(), |acc, y_j| acc + y_j); - PrivateKeyShare(ferveo_tdec::PrivateKeyShare(y_r.into_affine())) + Ok(PrivateKeyShare(ferveo_tdec::PrivateKeyShare( + y_r.into_affine(), + ))) } pub fn create_decryption_share_simple( @@ -97,7 +117,7 @@ impl PrivateKeyShare { let lagrange_coeff = &lagrange_coeffs .get(share_index as usize) .ok_or(Error::InvalidShareIndex(share_index))?; - DecryptionSharePrecomputed::new( + DecryptionSharePrecomputed::create( share_index as usize, &validator_keypair.decryption_key, &self.0, @@ -154,12 +174,12 @@ impl PrivateKeyShareUpdate for ShareRecoveryUpdate { impl ShareRecoveryUpdate { /// From PSS paper, section 4.2.1, (https://link.springer.com/content/pdf/10.1007/3-540-44750-4_27.pdf) pub fn create_share_updates( - domain_points: &[DomainPoint], + domain_points: &HashMap>, h: &E::G2Affine, x_r: &DomainPoint, threshold: u32, rng: &mut impl RngCore, - ) -> Vec> { + ) -> HashMap> { // Update polynomial has root at x_r prepare_share_updates_with_root::( domain_points, @@ -168,8 +188,8 @@ impl ShareRecoveryUpdate { threshold, rng, ) - .iter() - .map(|p| Self(p.clone())) + .into_iter() + .map(|(share_index, share_update)| (share_index, Self(share_update))) .collect() } } @@ -195,11 +215,11 @@ impl PrivateKeyShareUpdate for ShareRefreshUpdate { impl ShareRefreshUpdate { /// From PSS paper, section 4.2.1, (https://link.springer.com/content/pdf/10.1007/3-540-44750-4_27.pdf) pub fn create_share_updates( - domain_points: &[DomainPoint], + domain_points: &HashMap>, h: &E::G2Affine, threshold: u32, rng: &mut impl RngCore, - ) -> Vec> { + ) -> HashMap> { // Update polynomial has root at 0 prepare_share_updates_with_root::( domain_points, @@ -208,9 +228,10 @@ impl ShareRefreshUpdate { threshold, rng, ) - .iter() - .cloned() - .map(|p| ShareRefreshUpdate(p)) + .into_iter() + .map(|(share_index, share_update)| { + (share_index, ShareRefreshUpdate(share_update)) + }) .collect() } } @@ -221,24 +242,25 @@ impl ShareRefreshUpdate { /// The result is a list of share updates. /// We represent the share updates as `InnerPrivateKeyShare` to avoid dependency on the concrete implementation of `PrivateKeyShareUpdate`. fn prepare_share_updates_with_root( - domain_points: &[DomainPoint], + domain_points: &HashMap>, h: &E::G2Affine, root: &DomainPoint, threshold: u32, rng: &mut impl RngCore, -) -> Vec> { - // Generate a new random polynomial with defined root +) -> HashMap> { + // Generate a new random polynomial with a defined root let d_i = make_random_polynomial_with_root::(threshold - 1, root, rng); // Now, we need to evaluate the polynomial at each of participants' indices domain_points .iter() - .map(|x_i| { + .map(|(share_index, x_i)| { let eval = d_i.evaluate(x_i); - h.mul(eval).into_affine() + let share_update = + ferveo_tdec::PrivateKeyShare(h.mul(eval).into_affine()); + (*share_index, share_update) }) - .map(ferveo_tdec::PrivateKeyShare) - .collect() + .collect::>() } /// Generate a random polynomial with a given root @@ -288,13 +310,14 @@ mod tests_refresh { threshold: u32, x_r: &Fr, remaining_participants: &[PrivateDecryptionContextSimple], - ) -> Vec> { + ) -> HashMap> { // Each participant prepares an update for each other participant - let domain_points = remaining_participants[0] - .public_decryption_contexts + let domain_points = remaining_participants .iter() - .map(|c| c.domain) - .collect::>(); + .map(|c| { + (c.index as u32, c.public_decryption_contexts[c.index].domain) + }) + .collect::>(); let h = remaining_participants[0].public_decryption_contexts[0].h; let share_updates = remaining_participants .iter() @@ -306,25 +329,29 @@ mod tests_refresh { threshold, rng, ); - (p.index, share_updates) + (p.index as u32, share_updates) }) - .collect::>(); + .collect::>(); // Participants share updates and update their shares - let updated_private_key_shares: Vec<_> = remaining_participants + let updated_private_key_shares = remaining_participants .iter() .map(|p| { // Current participant receives updates from other participants let updates_for_participant: Vec<_> = share_updates .values() - .map(|updates| updates.get(p.index).cloned().unwrap()) + .map(|updates| { + updates.get(&(p.index as u32)).cloned().unwrap() + }) .collect(); // And updates their share - PrivateKeyShare(p.private_key_share.clone()) - .create_updated_key_share(&updates_for_participant) + let updated_share = + PrivateKeyShare(p.private_key_share.clone()) + .create_updated_key_share(&updates_for_participant); + (p.index as u32, updated_share) }) - .collect(); + .collect::>(); updated_private_key_shares } @@ -371,53 +398,64 @@ mod tests_refresh { &x_r, &remaining_participants, ); + // We only need `security_threshold` updates to recover the original share + let updated_private_key_shares = updated_private_key_shares + .into_iter() + .take(security_threshold as usize) + .collect::>(); // Now, we have to combine new share fragments into a new share - let domain_points = &remaining_participants[0] - .public_decryption_contexts - .iter() - .map(|ctxt| ctxt.domain) - .collect::>(); + let domain_points = remaining_participants + .into_iter() + .map(|ctxt| { + ( + ctxt.index as u32, + ctxt.public_decryption_contexts[ctxt.index].domain, + ) + }) + .collect::>(); let new_private_key_share = PrivateKeyShare::recover_share_from_updated_private_shares( &x_r, - &domain_points[..security_threshold as usize], - &updated_private_key_shares[..security_threshold as usize], - ); - + &domain_points, + &updated_private_key_shares, + ) + .unwrap(); assert_eq!(new_private_key_share, original_private_key_share); // If we don't have enough private share updates, the resulting private share will be incorrect - assert_eq!(domain_points.len(), updated_private_key_shares.len()); + let not_enough_shares = updated_private_key_shares + .into_iter() + .take(security_threshold as usize - 1) + .collect::>(); let incorrect_private_key_share = PrivateKeyShare::recover_share_from_updated_private_shares( &x_r, - &domain_points[..(security_threshold - 1) as usize], - &updated_private_key_shares - [..(security_threshold - 1) as usize], - ); - + &domain_points, + ¬_enough_shares, + ) + .unwrap(); assert_ne!(incorrect_private_key_share, original_private_key_share); } /// Ñ parties (where t <= Ñ <= N) jointly execute a "share recovery" algorithm, and the output is 1 new share. /// The new share is independent of the previously existing shares. We can use this to on-board a new participant into an existing cohort. - #[test_case(4, 4; "number of shares (validators) is a power of 2")] - #[test_case(7, 7; "number of shares (validators) is not a power of 2")] - fn tdec_simple_variant_share_recovery_at_random_point( - shares_num: u32, - _validators_num: u32, - ) { + #[test_case(4; "number of shares (validators) is a power of 2")] + #[test_case(7; "number of shares (validators) is not a power of 2")] + fn tdec_simple_variant_share_recovery_at_random_point(shares_num: u32) { let rng = &mut test_rng(); - let threshold = shares_num * 2 / 3; + let security_threshold = shares_num * 2 / 3; - let (_, shared_private_key, mut contexts) = - setup_simple::(threshold as usize, shares_num as usize, rng); + let (_, shared_private_key, mut contexts) = setup_simple::( + security_threshold as usize, + shares_num as usize, + rng, + ); // Prepare participants // Remove one participant from the contexts and all nested structures - contexts.pop().unwrap(); + let removed_participant = contexts.pop().unwrap(); let mut remaining_participants = contexts.clone(); for p in &mut remaining_participants { p.public_decryption_contexts.pop().unwrap(); @@ -428,52 +466,65 @@ mod tests_refresh { // Our random point let x_r = ScalarField::rand(rng); - // Each participant prepares an update for each other participant, and uses it to create a new share fragment - let share_recovery_fragmetns = create_updated_private_key_shares( + // Each remaining participant prepares an update for every other participant, and uses it to create a new share fragment + let share_recovery_updates = create_updated_private_key_shares( rng, - threshold, + security_threshold, &x_r, &remaining_participants, ); + // We only need `threshold` updates to recover the original share + let share_recovery_updates = share_recovery_updates + .into_iter() + .take(security_threshold as usize) + .collect::>(); + let domain_points = &mut remaining_participants + .into_iter() + .map(|ctxt| { + ( + ctxt.index as u32, + ctxt.public_decryption_contexts[ctxt.index].domain, + ) + }) + .collect::>(); // Now, we have to combine new share fragments into a new share - let domain_points = &mut remaining_participants[0] - .public_decryption_contexts - .iter() - .map(|ctxt| ctxt.domain) - .collect::>(); let recovered_private_key_share = PrivateKeyShare::recover_share_from_updated_private_shares( &x_r, - &domain_points[..threshold as usize], - &share_recovery_fragmetns[..threshold as usize], - ); - - let mut private_shares = contexts - .iter() - .cloned() - .map(|ctxt| ctxt.private_key_share) - .collect::>(); + domain_points, + &share_recovery_updates, + ) + .unwrap(); // Finally, let's recreate the shared private key from some original shares and the recovered one - domain_points.push(x_r); - private_shares.push(recovered_private_key_share.0.clone()); + let mut private_shares = contexts + .into_iter() + .map(|ctxt| (ctxt.index as u32, ctxt.private_key_share)) + .collect::>(); + + // Need to update these to account for recovered private key share + domain_points.insert(removed_participant.index as u32, x_r); + private_shares.insert( + removed_participant.index as u32, + recovered_private_key_share.0.clone(), + ); // This is a workaround for a type mismatch - We need to convert the private shares to updated private shares // This is just to test that we are able to recover the shared private key from the updated private shares let updated_private_key_shares = private_shares - .iter() - .cloned() - .map(UpdatedPrivateKeyShare::new) - .collect::>(); - let start_from = shares_num - threshold; + .into_iter() + .map(|(share_index, share)| { + (share_index, UpdatedPrivateKeyShare(share)) + }) + .collect::>(); let new_shared_private_key = PrivateKeyShare::recover_share_from_updated_private_shares( &ScalarField::zero(), - &domain_points[start_from as usize..], - &updated_private_key_shares[start_from as usize..], - ); - + domain_points, + &updated_private_key_shares, + ) + .unwrap(); assert_eq!(shared_private_key, new_shared_private_key.0); } @@ -483,16 +534,19 @@ mod tests_refresh { #[test_matrix([4, 7, 11, 16])] fn tdec_simple_variant_share_refreshing(shares_num: usize) { let rng = &mut test_rng(); - let threshold = shares_num * 2 / 3; + let security_threshold = shares_num * 2 / 3; let (_, private_key_share, contexts) = - setup_simple::(threshold, shares_num, rng); - - let domain_points = &contexts[0] - .public_decryption_contexts + setup_simple::(security_threshold, shares_num, rng); + let domain_points = &contexts .iter() - .map(|ctxt| ctxt.domain) - .collect::>(); + .map(|ctxt| { + ( + ctxt.index as u32, + ctxt.public_decryption_contexts[ctxt.index].domain, + ) + }) + .collect::>(); let h = contexts[0].public_decryption_contexts[0].h; // Each participant prepares an update for each other participant: @@ -503,37 +557,43 @@ mod tests_refresh { ShareRefreshUpdate::::create_share_updates( domain_points, &h, - threshold as u32, + security_threshold as u32, rng, ); - (p.index, share_updates) + (p.index as u32, share_updates) }) - .collect::>(); + .collect::>(); - // Participants "refresh" their shares with the updates from each other: - let refreshed_shares: Vec<_> = contexts + // Participants refresh their shares with the updates from each other: + let refreshed_shares = contexts .iter() .map(|p| { // Current participant receives updates from other participants let updates_for_participant: Vec<_> = share_updates .values() - .map(|updates| updates.get(p.index).cloned().unwrap()) + .map(|updates| { + updates.get(&(p.index as u32)).cloned().unwrap() + }) .collect(); // And creates a new, refreshed share - PrivateKeyShare(p.private_key_share.clone()) - .create_updated_key_share(&updates_for_participant) + let updated_share = + PrivateKeyShare(p.private_key_share.clone()) + .create_updated_key_share(&updates_for_participant); + (p.index as u32, updated_share) }) - .collect(); + // We only need `threshold` refreshed shares to recover the original share + .take(security_threshold) + .collect::>>(); // Finally, let's recreate the shared private key from the refreshed shares let new_shared_private_key = PrivateKeyShare::recover_share_from_updated_private_shares( &ScalarField::zero(), - &domain_points[..threshold], - &refreshed_shares[..threshold], - ); - + domain_points, + &refreshed_shares, + ) + .unwrap(); assert_eq!(private_key_share, new_shared_private_key.0); } } diff --git a/ferveo/src/test_common.rs b/ferveo/src/test_common.rs index df28f553..eea3a7da 100644 --- a/ferveo/src/test_common.rs +++ b/ferveo/src/test_common.rs @@ -127,6 +127,7 @@ pub fn make_messages( let sender = dkg.me.clone(); messages.push((sender, transcript)); } + messages.shuffle(rng); messages } @@ -139,7 +140,7 @@ pub fn setup_dealt_dkg_with_n_transcript_dealt( let rng = &mut ark_std::test_rng(); // Gather everyone's transcripts - // Use only the first `transcripts_to_use` transcripts + // Use only need the first `transcripts_to_use` transcripts let mut transcripts: Vec<_> = (0..transcripts_to_use) .map(|my_index| { let (dkg, _) = setup_dkg_for_n_validators( From 975dae0d5f8d1a2e5c061fbc8d11b1cc73c867d7 Mon Sep 17 00:00:00 2001 From: Piotr Roslaniec Date: Wed, 13 Mar 2024 21:37:45 +0100 Subject: [PATCH 2/2] fix: not using subset of participants in precomputed variant --- .../examples/server_api_precomputed.py | 13 +- ferveo-python/ferveo/__init__.pyi | 1 + ferveo-python/test/test_ferveo.py | 87 ++++----- ferveo-tdec/benches/tpke.rs | 14 +- ferveo-tdec/src/context.rs | 9 +- ferveo-tdec/src/lib.rs | 78 ++++---- ferveo-wasm/examples/node/src/main.test.ts | 56 +++--- ferveo-wasm/tests/node.rs | 23 ++- ferveo/src/api.rs | 172 +++++++++--------- ferveo/src/bindings_python.rs | 24 ++- ferveo/src/bindings_wasm.rs | 8 + ferveo/src/dkg.rs | 6 +- ferveo/src/lib.rs | 117 +++++++----- ferveo/src/pvss.rs | 8 +- ferveo/src/refresh.rs | 52 ++++-- ferveo/src/test_common.rs | 2 - 16 files changed, 390 insertions(+), 280 deletions(-) diff --git a/ferveo-python/examples/server_api_precomputed.py b/ferveo-python/examples/server_api_precomputed.py index 77e21e61..4ebed2ea 100644 --- a/ferveo-python/examples/server_api_precomputed.py +++ b/ferveo-python/examples/server_api_precomputed.py @@ -64,9 +64,13 @@ def gen_eth_addr(i: int) -> str: aad = "my-aad".encode() ciphertext = encrypt(msg, aad, client_aggregate.public_key) +# In precomputed variant, the client selects a subset of validators to use for decryption +selected_validators = validators[:security_threshold] +selected_keypairs = validator_keypairs[:security_threshold] + # Having aggregated the transcripts, the validators can now create decryption shares decryption_shares = [] -for validator, validator_keypair in zip(validators, validator_keypairs): +for validator, validator_keypair in zip(selected_validators, selected_keypairs): dkg = Dkg( tau=tau, shares_num=shares_num, @@ -83,13 +87,12 @@ def gen_eth_addr(i: int) -> str: # Create a decryption share for the ciphertext decryption_share = aggregate.create_decryption_share_precomputed( - dkg, ciphertext.header, aad, validator_keypair + dkg, ciphertext.header, aad, validator_keypair, selected_validators ) decryption_shares.append(decryption_share) -# We need `shares_num` decryption shares in precomputed variant -# TODO: This fails if shares_num != validators_num -decryption_shares = decryption_shares[:validators_num] +# We need at most `security_threshold` decryption shares +decryption_shares = decryption_shares[:security_threshold] # Now, the decryption share can be used to decrypt the ciphertext # This part is in the client API diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi index 69c77488..30059f4a 100644 --- a/ferveo-python/ferveo/__init__.pyi +++ b/ferveo-python/ferveo/__init__.pyi @@ -123,6 +123,7 @@ class AggregatedTranscript: ciphertext_header: CiphertextHeader, aad: bytes, validator_keypair: Keypair, + selected_validators: Sequence[Validator], ) -> DecryptionSharePrecomputed: ... @staticmethod def from_bytes(data: bytes) -> AggregatedTranscript: ... diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 45afe800..6c8fb2a4 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -19,16 +19,6 @@ def gen_eth_addr(i: int) -> str: return f"0x{i:040x}" - -def decryption_share_for_variant(v: FerveoVariant, agg_transcript): - if v == FerveoVariant.Simple: - return agg_transcript.create_decryption_share_simple - elif v == FerveoVariant.Precomputed: - return agg_transcript.create_decryption_share_precomputed - else: - raise ValueError("Unknown variant") - - def combine_shares_for_variant(v: FerveoVariant, decryption_shares): if v == FerveoVariant.Simple: return combine_decryption_shares_simple(decryption_shares) @@ -39,7 +29,11 @@ def combine_shares_for_variant(v: FerveoVariant, decryption_shares): def scenario_for_variant( - variant: FerveoVariant, shares_num, validators_num, threshold, dec_shares_to_use + variant: FerveoVariant, + shares_num, + validators_num, + threshold, + dec_shares_to_use ): if variant not in [FerveoVariant.Simple, FerveoVariant.Precomputed]: raise ValueError("Unknown variant: " + variant) @@ -47,11 +41,8 @@ def scenario_for_variant( if validators_num < shares_num: raise ValueError("validators_num must be >= shares_num") - # TODO: Validate that - # if variant == FerveoVariant.Precomputed and dec_shares_to_use != validators_num: - # raise ValueError( - # "In precomputed variant, dec_shares_to_use must be equal to validators_num" - # ) + if shares_num < threshold: + raise ValueError("shares_num must be >= threshold") tau = 1 validator_keypairs = [Keypair.random() for _ in range(0, validators_num)] @@ -86,18 +77,27 @@ def scenario_for_variant( ) server_aggregate = dkg.aggregate_transcripts(messages) assert server_aggregate.verify(validators_num, messages) - client_aggregate = AggregatedTranscript(messages) assert client_aggregate.verify(validators_num, messages) + # At this point, DKG is done, and we are proceeding to threshold decryption + # Client creates a ciphertext and requests decryption shares from validators msg = "abc".encode() aad = "my-aad".encode() ciphertext = encrypt(msg, aad, client_aggregate.public_key) + # In precomputed variant, the client selects a subset of validators to use for decryption + if variant == FerveoVariant.Precomputed: + selected_validators = validators[:threshold] + selected_validator_keypairs = validator_keypairs[:threshold] + else: + selected_validators = validators + selected_validator_keypairs = validator_keypairs + # Having aggregated the transcripts, the validators can now create decryption shares decryption_shares = [] - for validator, validator_keypair in zip(validators, validator_keypairs): + for validator, validator_keypair in zip(selected_validators, selected_validator_keypairs): assert validator.public_key == validator_keypair.public_key() print("validator: ", validator.share_index) @@ -108,12 +108,19 @@ def scenario_for_variant( validators=validators, me=validator, ) - pvss_aggregated = dkg.aggregate_transcripts(messages) - assert pvss_aggregated.verify(validators_num, messages) - - decryption_share = decryption_share_for_variant(variant, pvss_aggregated)( - dkg, ciphertext.header, aad, validator_keypair - ) + server_aggregate = dkg.aggregate_transcripts(messages) + assert server_aggregate.verify(validators_num, messages) + + if variant == FerveoVariant.Simple: + decryption_share = server_aggregate.create_decryption_share_simple( + dkg, ciphertext.header, aad, validator_keypair + ) + elif variant == FerveoVariant.Precomputed: + decryption_share = server_aggregate.create_decryption_share_precomputed( + dkg, ciphertext.header, aad, validator_keypair, selected_validators + ) + else: + raise ValueError("Unknown variant") decryption_shares.append(decryption_share) # We are limiting the number of decryption shares to use for testing purposes @@ -122,12 +129,7 @@ def scenario_for_variant( # Client combines the decryption shares and decrypts the ciphertext shared_secret = combine_shares_for_variant(variant, decryption_shares) - if variant == FerveoVariant.Simple and len(decryption_shares) < threshold: - with pytest.raises(ThresholdEncryptionError): - decrypt_with_shared_secret(ciphertext, aad, shared_secret) - return - - if variant == FerveoVariant.Precomputed and len(decryption_shares) < threshold: + if len(decryption_shares) < threshold: with pytest.raises(ThresholdEncryptionError): decrypt_with_shared_secret(ciphertext, aad, shared_secret) return @@ -137,8 +139,8 @@ def scenario_for_variant( def test_simple_tdec_has_enough_messages(): - shares_num = 4 - threshold = shares_num - 1 + shares_num = 8 + threshold = int(shares_num * 2 / 3) for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Simple, @@ -150,41 +152,44 @@ def test_simple_tdec_has_enough_messages(): def test_simple_tdec_doesnt_have_enough_messages(): - shares_num = 4 - threshold = shares_num - 1 + shares_num = 8 + threshold = int(shares_num * 2 / 3) + dec_shares_to_use = threshold - 1 for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Simple, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=validators_num - 1, + dec_shares_to_use=dec_shares_to_use, ) def test_precomputed_tdec_has_enough_messages(): - shares_num = 4 - threshold = shares_num # in precomputed variant, we need all shares + shares_num = 8 + threshold = int(shares_num * 2 / 3) + dec_shares_to_use = threshold for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Precomputed, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=validators_num, + dec_shares_to_use=dec_shares_to_use, ) def test_precomputed_tdec_doesnt_have_enough_messages(): - shares_num = 4 - threshold = shares_num # in precomputed variant, we need all shares + shares_num = 8 + threshold = int(shares_num * 2 / 3) + dec_shares_to_use = threshold - 1 for validators_num in [shares_num, shares_num + 2]: scenario_for_variant( FerveoVariant.Simple, shares_num=shares_num, validators_num=validators_num, threshold=threshold, - dec_shares_to_use=threshold - 1, + dec_shares_to_use=dec_shares_to_use, ) diff --git a/ferveo-tdec/benches/tpke.rs b/ferveo-tdec/benches/tpke.rs index b7a5b8f7..03711289 100644 --- a/ferveo-tdec/benches/tpke.rs +++ b/ferveo-tdec/benches/tpke.rs @@ -105,7 +105,7 @@ impl SetupSimple { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, privkey, contexts) = - setup_simple::(threshold, shares_num, rng); + setup_simple::(shares_num, threshold, rng); // Ciphertext.commitment is already computed to match U let ciphertext = @@ -124,10 +124,10 @@ impl SetupSimple { let pub_contexts = contexts[0].clone().public_decryption_contexts; let domain: Vec = pub_contexts.iter().map(|c| c.domain).collect(); - let lagrange = prepare_combine_simple::(&domain); + let lagrange_coeffs = prepare_combine_simple::(&domain); let shared_secret = - share_combine_simple::(&decryption_shares, &lagrange); + share_combine_simple::(&decryption_shares, &lagrange_coeffs); let shared = SetupShared { threshold, @@ -144,7 +144,7 @@ impl SetupSimple { contexts, pub_contexts, decryption_shares, - lagrange_coeffs: lagrange, + lagrange_coeffs, } } } @@ -200,6 +200,8 @@ pub fn bench_create_decryption_share(c: &mut Criterion) { }; let simple_precomputed = { let setup = SetupSimple::new(shares_num, MSG_SIZE_CASES[0], rng); + let selected_participants = + (0..setup.shared.threshold).collect::>(); move || { black_box( setup @@ -209,6 +211,7 @@ pub fn bench_create_decryption_share(c: &mut Criterion) { context.create_share_precomputed( &setup.shared.ciphertext.header().unwrap(), &setup.shared.aad, + &selected_participants, ) }) .collect::>(), @@ -295,6 +298,8 @@ pub fn bench_share_combine(c: &mut Criterion) { }; let simple_precomputed = { let setup = SetupSimple::new(shares_num, MSG_SIZE_CASES[0], rng); + // TODO: Use threshold instead of shares_num + let selected_participants = (0..shares_num).collect::>(); let decryption_shares: Vec<_> = setup .contexts @@ -304,6 +309,7 @@ pub fn bench_share_combine(c: &mut Criterion) { .create_share_precomputed( &setup.shared.ciphertext.header().unwrap(), &setup.shared.aad, + &selected_participants, ) .unwrap() }) diff --git a/ferveo-tdec/src/context.rs b/ferveo-tdec/src/context.rs index ed7faee0..ba697917 100644 --- a/ferveo-tdec/src/context.rs +++ b/ferveo-tdec/src/context.rs @@ -92,13 +92,14 @@ impl PrivateDecryptionContextSimple { &self, ciphertext_header: &CiphertextHeader, aad: &[u8], + selected_participants: &[usize], ) -> Result> { - let domain = self - .public_decryption_contexts + let selected_domain_points = selected_participants .iter() - .map(|c| c.domain) + .map(|i| self.public_decryption_contexts[*i].domain) .collect::>(); - let lagrange_coeffs = prepare_combine_simple::(&domain); + let lagrange_coeffs = + prepare_combine_simple::(&selected_domain_points); DecryptionSharePrecomputed::create( self.index, diff --git a/ferveo-tdec/src/lib.rs b/ferveo-tdec/src/lib.rs index e491bba7..e0086dbf 100644 --- a/ferveo-tdec/src/lib.rs +++ b/ferveo-tdec/src/lib.rs @@ -175,8 +175,8 @@ pub mod test_common { } pub fn setup_simple( - threshold: usize, shares_num: usize, + threshold: usize, rng: &mut impl rand::Rng, ) -> ( PublicKey, @@ -264,17 +264,17 @@ pub mod test_common { pub fn setup_precomputed( shares_num: usize, + threshold: usize, rng: &mut impl rand::Rng, ) -> ( PublicKey, PrivateKeyShare, Vec>, ) { - // In precomputed variant, the security threshold is equal to the number of shares - setup_simple::(shares_num, shares_num, rng) + setup_simple::(shares_num, threshold, rng) } - pub fn create_shared_secret( + pub fn create_shared_secret_simple( pub_contexts: &[PublicDecryptionContextSimple], decryption_shares: &[DecryptionShareSimple], ) -> SharedSecret { @@ -291,8 +291,12 @@ mod tests { use ark_ec::{pairing::Pairing, AffineRepr, CurveGroup}; use ark_std::{test_rng, UniformRand}; use ferveo_common::{FromBytes, ToBytes}; + use rand::seq::IteratorRandom; - use crate::test_common::{create_shared_secret, setup_simple, *}; + use crate::{ + api::DecryptionSharePrecomputed, + test_common::{create_shared_secret_simple, setup_simple, *}, + }; type E = ark_bls12_381::Bls12_381; type TargetField = ::TargetField; @@ -378,7 +382,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, rng); + setup_simple::(shares_num, threshold, rng); let ciphertext = encrypt::(SecretBox::new(msg), aad, &pubkey, rng).unwrap(); @@ -447,7 +451,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, &mut rng); + setup_simple::(shares_num, threshold, &mut rng); let g_inv = &contexts[0].setup_params.g_inv; let ciphertext = @@ -462,10 +466,10 @@ mod tests { }) .take(threshold) .collect(); - let pub_contexts = + let selected_contexts = contexts[0].public_decryption_contexts[..threshold].to_vec(); let shared_secret = - create_shared_secret(&pub_contexts, &decryption_shares); + create_shared_secret_simple(&selected_contexts, &decryption_shares); test_ciphertext_validation_fails( &msg, @@ -476,13 +480,18 @@ mod tests { ); // If we use less than threshold shares, we should fail - let decryption_shares = decryption_shares[..threshold - 1].to_vec(); - let pub_contexts = pub_contexts[..threshold - 1].to_vec(); - let shared_secret = - create_shared_secret(&pub_contexts, &decryption_shares); - - let result = - decrypt_with_shared_secret(&ciphertext, aad, &shared_secret, g_inv); + let not_enough_dec_shares = decryption_shares[..threshold - 1].to_vec(); + let not_enough_contexts = selected_contexts[..threshold - 1].to_vec(); + let bash_shared_secret = create_shared_secret_simple( + ¬_enough_contexts, + ¬_enough_dec_shares, + ); + let result = decrypt_with_shared_secret( + &ciphertext, + aad, + &bash_shared_secret, + g_inv, + ); assert!(result.is_err()); } @@ -490,30 +499,39 @@ mod tests { fn tdec_precomputed_variant_e2e() { let mut rng = &mut test_rng(); let shares_num = 16; + let threshold = shares_num * 2 / 3; let msg = "my-msg".as_bytes().to_vec(); let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_precomputed::(shares_num, &mut rng); + setup_precomputed::(shares_num, threshold, &mut rng); let g_inv = &contexts[0].setup_params.g_inv; let ciphertext = encrypt::(SecretBox::new(msg.clone()), aad, &pubkey, rng) .unwrap(); - let decryption_shares: Vec<_> = contexts + let selected_participants = + (0..threshold).choose_multiple(rng, threshold); + let selected_contexts = contexts + .iter() + .filter(|c| selected_participants.contains(&c.index)) + .cloned() + .collect::>(); + + let decryption_shares = selected_contexts .iter() .map(|context| { context .create_share_precomputed( &ciphertext.header().unwrap(), aad, + &selected_participants, ) .unwrap() }) - .collect(); + .collect::>(); let shared_secret = share_combine_precomputed::(&decryption_shares); - test_ciphertext_validation_fails( &msg, aad, @@ -522,19 +540,17 @@ mod tests { g_inv, ); - // Note that in this variant, if we use less than `share_num` shares, we will get a - // decryption error. - - let not_enough_shares = &decryption_shares[0..shares_num - 1]; - let bad_shared_secret = - share_combine_precomputed::(not_enough_shares); - assert!(decrypt_with_shared_secret( + // If we use less than threshold shares, we should fail + let not_enough_dec_shares = decryption_shares[..threshold - 1].to_vec(); + let bash_shared_secret = + share_combine_precomputed(¬_enough_dec_shares); + let result = decrypt_with_shared_secret( &ciphertext, aad, - &bad_shared_secret, + &bash_shared_secret, g_inv, - ) - .is_err()); + ); + assert!(result.is_err()); } #[test] @@ -546,7 +562,7 @@ mod tests { let aad: &[u8] = "my-aad".as_bytes(); let (pubkey, _, contexts) = - setup_simple::(threshold, shares_num, &mut rng); + setup_simple::(shares_num, threshold, &mut rng); let ciphertext = encrypt::(SecretBox::new(msg), aad, &pubkey, rng).unwrap(); diff --git a/ferveo-wasm/examples/node/src/main.test.ts b/ferveo-wasm/examples/node/src/main.test.ts index 57071421..ba5361c3 100644 --- a/ferveo-wasm/examples/node/src/main.test.ts +++ b/ferveo-wasm/examples/node/src/main.test.ts @@ -52,10 +52,9 @@ function setupTest( // every validator can aggregate the transcripts const dkg = new Dkg(TAU, sharesNum, threshold, validators, validators[0]); + // Both the server and the client can aggregate the transcripts and verify them const serverAggregate = dkg.aggregateTranscript(messages); expect(serverAggregate.verify(validatorsNum, messages)).toBe(true); - - // Client can also aggregate the transcripts and verify them const clientAggregate = new AggregatedTranscript(messages); expect(clientAggregate.verify(validatorsNum, messages)).toBe(true); @@ -81,8 +80,14 @@ describe("ferveo-wasm", () => { const sharesNum = 4; const threshold = sharesNum - 1; [sharesNum, sharesNum + 2].forEach((validatorsNum) => { - const { validatorKeypairs, validators, messages, msg, aad, ciphertext } = - setupTest(sharesNum, validatorsNum, threshold); + const { + validatorKeypairs, + validators, + messages, + msg, + aad, + ciphertext + } = setupTest(sharesNum, validatorsNum, threshold); // Having aggregated the transcripts, the validators can now create decryption shares const decryptionShares: DecryptionShareSimple[] = []; @@ -90,11 +95,11 @@ describe("ferveo-wasm", () => { expect(validator.publicKey.equals(keypair.publicKey)).toBe(true); const dkg = new Dkg(TAU, sharesNum, threshold, validators, validator); - const aggregate = dkg.aggregateTranscript(messages); - const isValid = aggregate.verify(validatorsNum, messages); + const serverAggregate = dkg.aggregateTranscript(messages); + const isValid = serverAggregate.verify(validatorsNum, messages); expect(isValid).toBe(true); - const decryptionShare = aggregate.createDecryptionShareSimple( + const decryptionShare = serverAggregate.createDecryptionShareSimple( dkg, ciphertext.header, aad, @@ -105,49 +110,52 @@ describe("ferveo-wasm", () => { // Now, the decryption share can be used to decrypt the ciphertext // This part is in the client API - const sharedSecret = combineDecryptionSharesSimple(decryptionShares); - - // The client should have access to the public parameters of the DKG - const plaintext = decryptWithSharedSecret(ciphertext, aad, sharedSecret); expect(Buffer.from(plaintext)).toEqual(msg); }); }); it("precomputed tdec variant", () => { - const sharesNum = 4; - const threshold = sharesNum; // threshold is equal to sharesNum in precomputed variant + const sharesNum = 8; + const threshold = sharesNum * 2 / 3; [sharesNum, sharesNum + 2].forEach((validatorsNum) => { - const { validatorKeypairs, validators, messages, msg, aad, ciphertext } = - setupTest(sharesNum, validatorsNum, threshold); + const { + validatorKeypairs, + validators, + messages, + msg, + aad, + ciphertext + } = setupTest(sharesNum, validatorsNum, threshold); + + // In precomputed variant, client selects a subset of validators to create decryption shares + const selectedValidators = validators.slice(0, threshold); + const selectedValidatorKeypairs = validatorKeypairs.slice(0, threshold); // Having aggregated the transcripts, the validators can now create decryption shares const decryptionShares: DecryptionSharePrecomputed[] = []; - zip(validators, validatorKeypairs).forEach(([validator, keypair]) => { + zip(selectedValidators, selectedValidatorKeypairs).forEach(([validator, keypair]) => { expect(validator.publicKey.equals(keypair.publicKey)).toBe(true); const dkg = new Dkg(TAU, sharesNum, threshold, validators, validator); - const aggregate = dkg.aggregateTranscript(messages); - const isValid = aggregate.verify(validatorsNum, messages); + const serverAggregate = dkg.aggregateTranscript(messages); + const isValid = serverAggregate.verify(validatorsNum, messages); expect(isValid).toBe(true); - const decryptionShare = aggregate.createDecryptionSharePrecomputed( + const decryptionShare = serverAggregate.createDecryptionSharePrecomputed( dkg, ciphertext.header, aad, - keypair + keypair, + selectedValidators, ); decryptionShares.push(decryptionShare); }); // Now, the decryption share can be used to decrypt the ciphertext // This part is in the client API - const sharedSecret = combineDecryptionSharesPrecomputed(decryptionShares); - - // The client should have access to the public parameters of the DKG - const plaintext = decryptWithSharedSecret(ciphertext, aad, sharedSecret); expect(Buffer.from(plaintext)).toEqual(msg); }); diff --git a/ferveo-wasm/tests/node.rs b/ferveo-wasm/tests/node.rs index d3a5ea43..cddde011 100644 --- a/ferveo-wasm/tests/node.rs +++ b/ferveo-wasm/tests/node.rs @@ -155,7 +155,7 @@ fn tdec_simple() { #[wasm_bindgen_test] fn tdec_precomputed() { let shares_num = 16; - let security_threshold = shares_num; // Must be equal to shares_num in precomputed variant + let security_threshold = shares_num * 2 / 3; for validators_num in [shares_num, shares_num + 2] { let ( validator_keypairs, @@ -167,6 +167,11 @@ fn tdec_precomputed() { ciphertext, ) = setup_dkg(shares_num, validators_num, security_threshold); + // In precomputed variant, the client selects a subset of validators to create decryption shares + let selected_validators = + validators[..(security_threshold as usize)].to_vec(); + let selected_validators_js = into_js_array(selected_validators); + // Having aggregated the transcripts, the validators can now create decryption shares let decryption_shares = zip_eq(validators, validator_keypairs) .map(|(validator, keypair)| { @@ -178,23 +183,23 @@ fn tdec_precomputed() { &validator, ) .unwrap(); - let aggregate = + let server_aggregate = dkg.aggregate_transcripts(&messages_js).unwrap(); - let is_valid = - aggregate.verify(validators_num, &messages_js).unwrap(); - assert!(is_valid); - aggregate + assert!(server_aggregate + .verify(validators_num, &messages_js) + .unwrap()); + server_aggregate .create_decryption_share_precomputed( &dkg, &ciphertext.header().unwrap(), &aad, &keypair, + &selected_validators_js, ) .unwrap() }) - // We need `shares_num` decryption shares in precomputed variant - // TODO: This fails if shares_num != validators_num - .take(validators_num as usize) + // We need `security_threshold` decryption shares to decrypt + .take(security_threshold as usize) .collect::>(); let decryption_shares_js = into_js_array(decryption_shares); diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index 7ab62c56..637750cc 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -308,22 +308,23 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators: &[Validator], ) -> Result { - // Prevent users from using the precomputed variant with improper DKG parameters - if dkg.0.dkg_params.shares_num() - != dkg.0.dkg_params.security_threshold() - { - return Err(Error::InvalidDkgParametersForPrecomputedVariant( - dkg.0.dkg_params.shares_num(), - dkg.0.dkg_params.security_threshold(), - )); - } - self.0.aggregate.create_decryption_share_simple_precomputed( + let selected_domain_points = selected_validators + .iter() + .filter_map(|v| { + dkg.0 + .get_domain_point(v.share_index) + .ok() + .map(|domain_point| (v.share_index, domain_point)) + }) + .collect::>>(); + self.0.aggregate.create_decryption_share_precomputed( &ciphertext_header.0, aad, validator_keypair, dkg.0.me.share_index, - &dkg.0.domain_points(), + &selected_domain_points, ) } @@ -544,22 +545,21 @@ impl PrivateKeyShare { } /// Make a decryption share (precomputed variant) for a given ciphertext - pub fn create_decryption_share_simple_precomputed( + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points: &HashMap, ) -> Result { - let share = self.0.create_decryption_share_simple_precomputed( + self.0.create_decryption_share_precomputed( &ciphertext_header.0, aad, validator_keypair, share_index, domain_points, - )?; - Ok(share) + ) } pub fn to_bytes(&self) -> Result> { @@ -641,10 +641,8 @@ mod test_ferveo_api { #[test_case(7, 7; "number of shares (validators) is not a power of 2")] #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_server_api_tdec_precomputed(shares_num: u32, validators_num: u32) { + let security_threshold = shares_num * 2 / 3; let rng = &mut StdRng::seed_from_u64(0); - - // In precomputed variant, the security threshold is equal to the number of shares - let security_threshold = shares_num; let (messages, validators, validator_keypairs) = make_test_inputs( rng, TAU, @@ -660,52 +658,65 @@ mod test_ferveo_api { let dkg = Dkg::new(TAU, shares_num, security_threshold, &validators, &me) .unwrap(); - let pvss_aggregated = dkg.aggregate_transcripts(messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, messages).unwrap()); + let local_aggregate = dkg.aggregate_transcripts(messages).unwrap(); + assert!(local_aggregate.verify(validators_num, messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key - let dkg_public_key = pvss_aggregated.public_key(); + let dkg_public_key = local_aggregate.public_key(); // In the meantime, the client creates a ciphertext and decryption request let ciphertext = encrypt(SecretBox::new(MSG.to_vec()), AAD, &dkg_public_key) .unwrap(); + // In precomputed variant, client selects a specific subset of validators to create + // decryption shares + let selected_validators: Vec<_> = validators + .choose_multiple(rng, security_threshold as usize) + .cloned() + .collect(); + // Having aggregated the transcripts, the validators can now create decryption shares - let mut decryption_shares: Vec<_> = - izip!(&validators, &validator_keypairs) - .map(|(validator, validator_keypair)| { - // Each validator holds their own instance of DKG and creates their own aggregate - let dkg = Dkg::new( - TAU, - shares_num, - security_threshold, - &validators, - validator, - ) + let mut decryption_shares = selected_validators + .iter() + .map(|validator| { + let validator_keypair = validator_keypairs + .iter() + .find(|kp| kp.public_key() == validator.public_key) .unwrap(); - let aggregate = - dkg.aggregate_transcripts(messages).unwrap(); - assert!(pvss_aggregated - .verify(validators_num, messages) - .unwrap()); - - // And then each validator creates their own decryption share - aggregate - .create_decryption_share_precomputed( - &dkg, - &ciphertext.header().unwrap(), - AAD, - validator_keypair, - ) - .unwrap() - }) - .collect(); + // Each validator holds their own instance of DKG and creates their own aggregate + let dkg = Dkg::new( + TAU, + shares_num, + security_threshold, + &validators, + validator, + ) + .unwrap(); + let server_aggregate = + dkg.aggregate_transcripts(messages).unwrap(); + assert!(server_aggregate + .verify(validators_num, messages) + .unwrap()); + + // And then each validator creates their own decryption share + server_aggregate + .create_decryption_share_precomputed( + &dkg, + &ciphertext.header().unwrap(), + AAD, + validator_keypair, + &selected_validators, + ) + .unwrap() + }) + // We only need `security_threshold` shares to be able to decrypt + .take(security_threshold as usize) + .collect::>(); decryption_shares.shuffle(rng); // Now, the decryption share can be used to decrypt the ciphertext // This part is part of the client API - let shared_secret = share_combine_precomputed(&decryption_shares); let plaintext = decrypt_with_shared_secret( &ciphertext, @@ -715,10 +726,13 @@ mod test_ferveo_api { .unwrap(); assert_eq!(plaintext, MSG); - // Since we're using a precomputed variant, we need all the shares to be able to decrypt + // We need `security_threshold` shares to be able to decrypt // So if we remove one share, we should not be able to decrypt - let decryption_shares = - decryption_shares[..shares_num as usize - 1].to_vec(); + let decryption_shares = decryption_shares + .iter() + .take(security_threshold as usize - 1) + .cloned() + .collect::>(); let shared_secret = share_combine_precomputed(&decryption_shares); let result = decrypt_with_shared_secret( &ciphertext, @@ -733,7 +747,6 @@ mod test_ferveo_api { #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_server_api_tdec_simple(shares_num: u32, validators_num: u32) { let rng = &mut StdRng::seed_from_u64(0); - let security_threshold = shares_num / 2 + 1; let (messages, validators, validator_keypairs) = make_test_inputs( rng, @@ -747,19 +760,11 @@ mod test_ferveo_api { // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts - let dkg = Dkg::new( - TAU, - shares_num, - security_threshold, - &validators, - &validators[0], - ) - .unwrap(); - let pvss_aggregated = dkg.aggregate_transcripts(messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, messages).unwrap()); + let local_aggregate = AggregatedTranscript::new(messages).unwrap(); + assert!(local_aggregate.verify(validators_num, messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key - let public_key = pvss_aggregated.public_key(); + let public_key = local_aggregate.public_key(); // In the meantime, the client creates a ciphertext and decryption request let ciphertext = @@ -778,12 +783,12 @@ mod test_ferveo_api { validator, ) .unwrap(); - let aggregate = + let server_aggregate = dkg.aggregate_transcripts(messages).unwrap(); - assert!(aggregate + assert!(server_aggregate .verify(validators_num, messages) .unwrap()); - aggregate + server_aggregate .create_decryption_share_simple( &dkg, &ciphertext.header().unwrap(), @@ -792,13 +797,13 @@ mod test_ferveo_api { ) .unwrap() }) + // We only need `security_threshold` shares to be able to decrypt + .take(security_threshold as usize) .collect(); decryption_shares.shuffle(rng); // Now, the decryption share can be used to decrypt the ciphertext // This part is part of the client API - - // In simple variant, we only need `security_threshold` shares to be able to decrypt let decryption_shares = decryption_shares[..security_threshold as usize].to_vec(); @@ -808,8 +813,8 @@ mod test_ferveo_api { .unwrap(); assert_eq!(plaintext, MSG); - // Let's say that we've only received `security_threshold - 1` shares - // In this case, we should not be able to decrypt + // We need `security_threshold` shares to be able to decrypt + // So if we remove one share, we should not be able to decrypt let decryption_shares = decryption_shares[..security_threshold as usize - 1].to_vec(); @@ -819,9 +824,9 @@ mod test_ferveo_api { assert!(result.is_err()); } - // Note that the server and client code are using the same underlying - // implementation for aggregation and aggregate verification. - // Here, we focus on testing user-facing APIs for server and client users. + /// Note that the server and client code are using the same underlying + /// implementation for aggregation and aggregate verification. + /// Here, we focus on testing user-facing APIs for server and client users. #[test_case(4, 4; "number of shares (validators) is a power of 2")] #[test_case(7, 7; "number of shares (validators) is not a power of 2")] @@ -829,7 +834,6 @@ mod test_ferveo_api { fn server_side_local_verification(shares_num: u32, validators_num: u32) { let rng = &mut StdRng::seed_from_u64(0); let security_threshold = shares_num / 2 + 1; - let (messages, validators, _) = make_test_inputs( rng, TAU, @@ -940,7 +944,6 @@ mod test_ferveo_api { fn client_side_local_verification(shares_num: u32, validators_num: u32) { let rng = &mut StdRng::seed_from_u64(0); let security_threshold = shares_num / 2 + 1; - let (messages, _, _) = make_test_inputs( rng, TAU, @@ -1041,11 +1044,11 @@ mod test_ferveo_api { // Creating a copy to avoiding accidentally changing DKG state let dkg = dkgs[0].clone(); - let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(validators_num, &messages).unwrap()); + let server_aggregate = dkg.aggregate_transcripts(&messages).unwrap(); + assert!(server_aggregate.verify(validators_num, &messages).unwrap()); // Create an initial shared secret for testing purposes - let public_key = pvss_aggregated.public_key(); + let public_key = server_aggregate.public_key(); let ciphertext = encrypt(SecretBox::new(MSG.to_vec()), AAD, &public_key).unwrap(); let ciphertext_header = ciphertext.header().unwrap(); @@ -1062,7 +1065,6 @@ mod test_ferveo_api { validator_keypairs.as_slice(), &transcripts, ); - ( messages, validators, @@ -1084,7 +1086,6 @@ mod test_ferveo_api { ) { let rng = &mut StdRng::seed_from_u64(0); let security_threshold = shares_num / 2 + 1; - let ( mut messages, mut validators, @@ -1240,8 +1241,6 @@ mod test_ferveo_api { .unwrap(); decryption_shares.push(new_decryption_share); domain_points.insert(new_validator_share_index, x_r); - assert_eq!(domain_points.len(), validators_num as usize); - assert_eq!(decryption_shares.len(), validators_num as usize); let domain_points = domain_points .values() @@ -1269,7 +1268,6 @@ mod test_ferveo_api { ) { let rng = &mut StdRng::seed_from_u64(0); let security_threshold = shares_num / 2 + 1; - let ( messages, _validators, @@ -1355,6 +1353,8 @@ mod test_ferveo_api { ) .unwrap() }) + // We only need `security_threshold` shares to be able to decrypt + .take(security_threshold as usize) .collect(); decryption_shares.shuffle(rng); diff --git a/ferveo/src/bindings_python.rs b/ferveo/src/bindings_python.rs index 0689a495..bd189e7e 100644 --- a/ferveo/src/bindings_python.rs +++ b/ferveo/src/bindings_python.rs @@ -621,7 +621,10 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators: Vec, ) -> PyResult { + let selected_validators: Vec<_> = + selected_validators.into_iter().map(|v| v.0).collect(); let decryption_share = self .0 .create_decryption_share_precomputed( @@ -629,6 +632,7 @@ impl AggregatedTranscript { &ciphertext_header.0, aad, &validator_keypair.0, + &selected_validators, ) .map_err(FerveoPythonError::FerveoError)?; Ok(DecryptionSharePrecomputed(decryption_share)) @@ -841,9 +845,7 @@ mod test_ferveo_python { #[test_case(4, 4; "number of validators equal to the number of shares")] #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_server_api_tdec_precomputed(shares_num: u32, validators_num: u32) { - // In precomputed variant, the security threshold is equal to the number of shares - let security_threshold = shares_num; - + let security_threshold = shares_num * 2 / 3; let (messages, validators, validator_keypairs) = make_test_inputs( TAU, security_threshold, @@ -866,18 +868,21 @@ mod test_ferveo_python { // Let's say that we've only received `security_threshold` transcripts let messages = messages[..security_threshold as usize].to_vec(); - let pvss_aggregated = + let local_aggregate = dkg.aggregate_transcripts(messages.clone()).unwrap(); - assert!(pvss_aggregated + assert!(local_aggregate .verify(validators_num, messages.clone()) .unwrap()); // At this point, any given validator should be able to provide a DKG public key - let dkg_public_key = pvss_aggregated.public_key(); + let dkg_public_key = local_aggregate.public_key(); // In the meantime, the client creates a ciphertext and decryption request let ciphertext = encrypt(MSG.to_vec(), AAD, &dkg_public_key).unwrap(); + // TODO: Adjust the subset of validators to be used in the decryption for precomputed + // variant + // Having aggregated the transcripts, the validators can now create decryption shares let decryption_shares: Vec<_> = izip!(validators.clone(), &validator_keypairs) @@ -891,18 +896,19 @@ mod test_ferveo_python { &validator, ) .unwrap(); - let aggregate = validator_dkg + let server_aggregate = validator_dkg .aggregate_transcripts(messages.clone()) .unwrap(); - assert!(pvss_aggregated + assert!(server_aggregate .verify(validators_num, messages.clone()) .is_ok()); - aggregate + server_aggregate .create_decryption_share_precomputed( &validator_dkg, &ciphertext.header().unwrap(), AAD, validator_keypair, + validators.clone(), ) .unwrap() }) diff --git a/ferveo/src/bindings_wasm.rs b/ferveo/src/bindings_wasm.rs index 56325092..0b369874 100644 --- a/ferveo/src/bindings_wasm.rs +++ b/ferveo/src/bindings_wasm.rs @@ -536,8 +536,15 @@ impl AggregatedTranscript { ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, + selected_validators_js: &ValidatorArray, ) -> JsResult { set_panic_hook(); + let selected_validators = + try_from_js_array::(selected_validators_js)?; + let selected_validators = selected_validators + .into_iter() + .map(|v| v.to_inner()) + .collect::>>()?; let decryption_share = self .0 .create_decryption_share_precomputed( @@ -545,6 +552,7 @@ impl AggregatedTranscript { &ciphertext_header.0, aad, &validator_keypair.0, + &selected_validators, ) .map_err(map_js_err)?; Ok(DecryptionSharePrecomputed(decryption_share)) diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index d2e825a7..360e5202 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -169,10 +169,10 @@ impl PubliclyVerifiableDkg { /// Return a map of domain points for the DKG pub fn domain_point_map(&self) -> HashMap> { - self.domain_points() - .iter() + self.domain + .elements() .enumerate() - .map(|(i, point)| (i as u32, *point)) + .map(|(i, point)| (i as u32, point)) .collect::>() } diff --git a/ferveo/src/lib.rs b/ferveo/src/lib.rs index 67e501fe..e3638fe6 100644 --- a/ferveo/src/lib.rs +++ b/ferveo/src/lib.rs @@ -138,9 +138,9 @@ mod test_dkg_full { Vec>, SharedSecret, ) { - let pvss_aggregated = + let server_aggregate = AggregatedTranscript::from_transcripts(transcripts).unwrap(); - assert!(pvss_aggregated + assert!(server_aggregate .aggregate .verify_aggregation(dkg, transcripts) .unwrap()); @@ -152,7 +152,7 @@ mod test_dkg_full { let validator = dkg .get_validator(&validator_keypair.public_key()) .unwrap(); - pvss_aggregated + server_aggregate .aggregate .create_decryption_share_simple( ciphertext_header, @@ -175,7 +175,7 @@ mod test_dkg_full { &decryption_shares, &lagrange_coeffs, ); - (pvss_aggregated, decryption_shares, shared_secret) + (server_aggregate, decryption_shares, shared_secret) } #[test_case(4, 4; "number of shares (validators) is a power of 2")] @@ -183,8 +183,7 @@ mod test_dkg_full { #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_dkg_simple_tdec(shares_num: u32, validators_num: u32) { let rng = &mut test_rng(); - - let security_threshold = shares_num / 2 + 1; + let security_threshold = shares_num * 2 / 3; let (dkg, validator_keypairs, messages) = setup_dealt_dkg_with_n_validators( security_threshold, @@ -196,17 +195,19 @@ mod test_dkg_full { .take(shares_num as usize) .map(|m| m.1.clone()) .collect::>(); - let public_key = AggregatedTranscript::from_transcripts(&transcripts) - .unwrap() - .public_key; + let local_aggregate = + AggregatedTranscript::from_transcripts(&transcripts).unwrap(); + assert!(local_aggregate + .aggregate + .verify_aggregation(&dkg, &transcripts) + .unwrap()); let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), AAD, - &public_key, + &local_aggregate.public_key, rng, ) .unwrap(); - let (_, _, shared_secret) = create_shared_secret_simple_tdec( &dkg, AAD, @@ -227,62 +228,77 @@ mod test_dkg_full { #[test_case(4, 4; "number of shares (validators) is a power of 2")] #[test_case(7, 7; "number of shares (validators) is not a power of 2")] - // TODO: This test fails: - // #[test_case(4, 6; "number of validators greater than the number of shares")] + #[test_case(4, 6; "number of validators greater than the number of shares")] fn test_dkg_simple_tdec_precomputed(shares_num: u32, validators_num: u32) { let rng = &mut test_rng(); - - // In precomputed variant, threshold must be equal to shares_num - let security_threshold = shares_num; + let security_threshold = shares_num * 2 / 3; let (dkg, validator_keypairs, messages) = - setup_dealt_dkg_with_n_validators( + setup_dealt_dkg_with_n_transcript_dealt( security_threshold, shares_num, validators_num, + shares_num, ); let transcripts = messages .iter() .take(shares_num as usize) .map(|m| m.1.clone()) .collect::>(); - let pvss_aggregated = + let local_aggregate = AggregatedTranscript::from_transcripts(&transcripts).unwrap(); - assert!(pvss_aggregated + assert!(local_aggregate .aggregate .verify_aggregation(&dkg, &transcripts) .unwrap()); - let public_key = pvss_aggregated.public_key; let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), AAD, - &public_key, + &local_aggregate.public_key, rng, ) .unwrap(); + // In precomputed variant, client selects a specific subset of validators to create + // decryption shares + let selected_keypairs = validator_keypairs + .choose_multiple(rng, security_threshold as usize) + .collect::>(); + let selected_validators = selected_keypairs + .iter() + .map(|keypair| { + dkg.get_validator(&keypair.public_key()) + .expect("Validator not found") + }) + .collect::>(); + let selected_domain_points = selected_validators + .iter() + .filter_map(|v| { + dkg.get_domain_point(v.share_index) + .ok() + .map(|domain_point| (v.share_index, domain_point)) + }) + .collect::>>(); + let mut decryption_shares: Vec> = - validator_keypairs + selected_keypairs .iter() .map(|validator_keypair| { let validator = dkg .get_validator(&validator_keypair.public_key()) .unwrap(); - pvss_aggregated + local_aggregate .aggregate - .create_decryption_share_simple_precomputed( + .create_decryption_share_precomputed( &ciphertext.header().unwrap(), AAD, validator_keypair, validator.share_index, - &dkg.domain_points(), + &selected_domain_points, ) .unwrap() }) - // We take only the first `security_threshold` decryption shares - .take(dkg.dkg_params.security_threshold() as usize) .collect(); - - // Order of decryption shares is not important in the precomputed variant + // Order of decryption shares is not important decryption_shares.shuffle(rng); // Decrypt with precomputed variant @@ -307,7 +323,6 @@ mod test_dkg_full { ) { let rng = &mut test_rng(); let security_threshold = shares_num / 2 + 1; - let (dkg, validator_keypairs, messages) = setup_dealt_dkg_with_n_validators( security_threshold, @@ -319,18 +334,21 @@ mod test_dkg_full { .take(shares_num as usize) .map(|m| m.1.clone()) .collect::>(); - let public_key = AggregatedTranscript::from_transcripts(&transcripts) - .unwrap() - .public_key; + let local_aggregate = + AggregatedTranscript::from_transcripts(&transcripts).unwrap(); + assert!(local_aggregate + .aggregate + .verify_aggregation(&dkg, &transcripts) + .unwrap()); let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), AAD, - &public_key, + &local_aggregate.public_key, rng, ) .unwrap(); - let (pvss_aggregated, decryption_shares, _) = + let (local_aggregate, decryption_shares, _) = create_shared_secret_simple_tdec( &dkg, AAD, @@ -340,7 +358,7 @@ mod test_dkg_full { ); izip!( - &pvss_aggregated.aggregate.shares, + &local_aggregate.aggregate.shares, &validator_keypairs, &decryption_shares, ) @@ -362,7 +380,7 @@ mod test_dkg_full { let mut with_bad_decryption_share = decryption_share.clone(); with_bad_decryption_share.decryption_share = TargetField::zero(); assert!(!with_bad_decryption_share.verify( - &pvss_aggregated.aggregate.shares[0], + &local_aggregate.aggregate.shares[0], &validator_keypairs[0].public_key().encryption_key, &dkg.pvss_params.h, &ciphertext, @@ -372,7 +390,7 @@ mod test_dkg_full { let mut with_bad_checksum = decryption_share; with_bad_checksum.validator_checksum.checksum = G1Affine::zero(); assert!(!with_bad_checksum.verify( - &pvss_aggregated.aggregate.shares[0], + &local_aggregate.aggregate.shares[0], &validator_keypairs[0].public_key().encryption_key, &dkg.pvss_params.h, &ciphertext, @@ -388,7 +406,6 @@ mod test_dkg_full { ) { let rng = &mut test_rng(); let security_threshold = shares_num / 2 + 1; - let (dkg, validator_keypairs, messages) = setup_dealt_dkg_with_n_validators( security_threshold, @@ -400,13 +417,16 @@ mod test_dkg_full { .take(shares_num as usize) .map(|m| m.1.clone()) .collect::>(); - let public_key = AggregatedTranscript::from_transcripts(&transcripts) - .unwrap() - .public_key; + let local_aggregate = + AggregatedTranscript::from_transcripts(&transcripts).unwrap(); + assert!(local_aggregate + .aggregate + .verify_aggregation(&dkg, &transcripts) + .unwrap()); let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), AAD, - &public_key, + &local_aggregate.public_key, rng, ) .unwrap(); @@ -598,13 +618,16 @@ mod test_dkg_full { .take(shares_num as usize) .map(|m| m.1.clone()) .collect::>(); - let public_key = AggregatedTranscript::from_transcripts(&transcripts) - .unwrap() - .public_key; + let local_aggregate = + AggregatedTranscript::from_transcripts(&transcripts).unwrap(); + assert!(local_aggregate + .aggregate + .verify_aggregation(&dkg, &transcripts) + .unwrap()); let ciphertext = ferveo_tdec::encrypt::( SecretBox::new(MSG.to_vec()), AAD, - &public_key, + &local_aggregate.public_key, rng, ) .unwrap(); diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index 700db9cb..8d1affeb 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -1,4 +1,4 @@ -use std::{hash::Hash, marker::PhantomData, ops::Mul}; +use std::{collections::HashMap, hash::Hash, marker::PhantomData, ops::Mul}; use ark_ec::{pairing::Pairing, AffineRepr, CurveGroup, Group}; use ark_ff::{Field, Zero}; @@ -358,16 +358,16 @@ impl PubliclyVerifiableSS { /// Make a decryption share (precomputed variant) for a given ciphertext /// With this method, we wrap the PrivateKeyShare method to avoid exposing the private key share // TODO: Consider deprecating to use PrivateKeyShare method directly - pub fn create_decryption_share_simple_precomputed( + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points: &HashMap>, ) -> Result> { self.decrypt_private_key_share(validator_keypair, share_index)? - .create_decryption_share_simple_precomputed( + .create_decryption_share_precomputed( ciphertext_header, aad, validator_keypair, diff --git a/ferveo/src/refresh.rs b/ferveo/src/refresh.rs index 0b8a95ef..d7700cfa 100644 --- a/ferveo/src/refresh.rs +++ b/ferveo/src/refresh.rs @@ -102,21 +102,51 @@ impl PrivateKeyShare { .map_err(|e| e.into()) } - pub fn create_decryption_share_simple_precomputed( + /// In precomputed variant, we offload some of the decryption related computation to the server-side: + /// We use the `prepare_combine_simple` function to precompute the lagrange coefficients + pub fn create_decryption_share_precomputed( &self, ciphertext_header: &CiphertextHeader, aad: &[u8], validator_keypair: &Keypair, share_index: u32, - domain_points: &[DomainPoint], + domain_points_map: &HashMap>, ) -> Result> { - let g_inv = PubliclyVerifiableParams::::default().g_inv(); - // In precomputed variant, we offload some of the decryption related computation to the server-side: - // We use the `prepare_combine_simple` function to precompute the lagrange coefficients - let lagrange_coeffs = prepare_combine_simple::(domain_points); - let lagrange_coeff = &lagrange_coeffs - .get(share_index as usize) + // We need to turn the domain points into a vector, and sort it by share index + let mut domain_points = domain_points_map + .iter() + .map(|(share_index, domain_point)| (*share_index, *domain_point)) + .collect::>(); + domain_points.sort_by_key(|(share_index, _)| *share_index); + + // Now, we have to pass the domain points to the `prepare_combine_simple` function + // and use the resulting lagrange coefficients to create the decryption share + + let only_domain_points = domain_points + .iter() + .map(|(_, domain_point)| *domain_point) + .collect::>(); + let lagrange_coeffs = prepare_combine_simple::(&only_domain_points); + + // Before we pick the lagrange coefficient for the current share index, we need + // to map the share index to the index in the domain points vector + // Given that we sorted the domain points by share index, the first element in the vector + // will correspond to the smallest share index, second to the second smallest, and so on + + let sorted_share_indices = domain_points + .iter() + .enumerate() + .map(|(adjusted_share_index, (share_index, _))| { + (*share_index, adjusted_share_index) + }) + .collect::>(); + let adjusted_share_index = *sorted_share_indices + .get(&share_index) .ok_or(Error::InvalidShareIndex(share_index))?; + + // Finally, pick the lagrange coefficient for the current share index + let lagrange_coeff = &lagrange_coeffs[adjusted_share_index]; + let g_inv = PubliclyVerifiableParams::::default().g_inv(); DecryptionSharePrecomputed::create( share_index as usize, &validator_keypair.decryption_key, @@ -368,8 +398,8 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, _, mut contexts) = setup_simple::( - security_threshold as usize, shares_num as usize, + security_threshold as usize, rng, ); @@ -447,8 +477,8 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, shared_private_key, mut contexts) = setup_simple::( - security_threshold as usize, shares_num as usize, + security_threshold as usize, rng, ); @@ -537,7 +567,7 @@ mod tests_refresh { let security_threshold = shares_num * 2 / 3; let (_, private_key_share, contexts) = - setup_simple::(security_threshold, shares_num, rng); + setup_simple::(shares_num, security_threshold, rng); let domain_points = &contexts .iter() .map(|ctxt| { diff --git a/ferveo/src/test_common.rs b/ferveo/src/test_common.rs index eea3a7da..a1a3d130 100644 --- a/ferveo/src/test_common.rs +++ b/ferveo/src/test_common.rs @@ -90,8 +90,6 @@ pub fn setup_dealt_dkg() -> DealtTestSetup { setup_dealt_dkg_with(SECURITY_THRESHOLD, SHARES_NUM) } -// TODO: Rewrite setup_utils to return messages separately - pub fn setup_dealt_dkg_with( security_threshold: u32, shares_num: u32,