Skip to content

Commit

Permalink
chore: make more add/sub test use variable num_blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Oct 25, 2024
1 parent e9af460 commit 7209633
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 128 deletions.
4 changes: 1 addition & 3 deletions tfhe/benches/core_crypto/ks_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,7 @@ mod cuda {

#[cfg(feature = "gpu")]
use cuda::cuda_keyswitch_group;
use tfhe::shortint::parameters::{
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use tfhe::shortint::parameters::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

pub fn keyswitch_group() {
let mut criterion: Criterion<_> = (Criterion::default()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ where
for num_blocks in 1..MAX_NB_CTXT {
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if modulus == 1 {
// Basically have one bit the sign bit can't really test
// Basically have one bit, the sign bit can't really test
continue;
}

Expand Down
208 changes: 109 additions & 99 deletions tfhe/src/integer/server_key/radix_parallel/tests_signed/test_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::integer::server_key::radix_parallel::tests_signed::{
};
use crate::integer::server_key::radix_parallel::tests_unsigned::{
nb_tests_for_params, nb_tests_smaller_for_params, nb_unchecked_tests_for_params,
CpuFunctionExecutor,
CpuFunctionExecutor, MAX_NB_CTXT,
};
use crate::integer::tests::create_parametrized_test;
use crate::integer::{
Expand Down Expand Up @@ -93,111 +93,116 @@ where

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_suv for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
for num_blocks in 1..MAX_NB_CTXT {
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if modulus == 1 {
// Basically have one bit, the sign bit can't really test
continue;
}

for _ in 0..nb_tests_smaller {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_3 = random_non_zero_value(&mut rng, modulus);
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);

let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check, \n\n\n msg0: {clear_0}, msg1: {clear_1}, \n\n\nct0: {ctxt_0:?}, \n\n\nct1: {ctxt_1:?}\n\n\n");

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
"Invalid overflow flag result for overflowing_suv for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);

for _ in 0..nb_tests_smaller {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_3 = random_non_zero_value(&mut rng, modulus);

let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

let clear_lhs = signed_add_under_modulus(clear_0, clear_2, modulus);
let clear_rhs = signed_add_under_modulus(clear_1, clear_3, modulus);

let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);

let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
}
}

// Test with trivial inputs, as it was bugged at some point
for _ in 0..4 {
// Reduce maximum value of random number such that at least the last block is a trivial 0
// (This is how the reproducing case was found)
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
// Test with trivial inputs, as it was bugged at some point
for _ in 0..4 {
// Reduce maximum value of random number such that at least the last block is a trivial
// 0 (This is how the reproducing case was found)
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);

let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));

let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);

let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
}
}

Expand Down Expand Up @@ -442,37 +447,42 @@ where

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;

executor.setup(&cks, sks);

let mut clear;

for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
for num_blocks in 1..MAX_NB_CTXT {
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
if modulus == 1 {
// Basically have one bit, the sign bit can't really test
continue;
}

let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct);
for _ in 0..nb_tests_smaller {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;

clear = signed_sub_under_modulus(clear_0, clear_1, modulus);
let ctxt_0 = cks.as_ref().encrypt_signed_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_signed_radix(clear_1, num_blocks);

// sub multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&ct_res, &ctxt_0));
let mut ct_res = executor.execute((&ctxt_0, &ctxt_1));
let tmp_ct = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
clear = signed_sub_under_modulus(clear, clear_0, modulus);
assert_eq!(ct_res, tmp_ct);

clear = signed_sub_under_modulus(clear_0, clear_1, modulus);

// sub multiple times to raise the degree
for _ in 0..nb_tests_smaller {
ct_res = executor.execute((&ct_res, &ctxt_0));
assert!(ct_res.block_carries_are_empty());
clear = signed_sub_under_modulus(clear, clear_0, modulus);

let dec_res: i64 = cks.decrypt_signed(&ct_res);
let dec_res: i64 = cks.decrypt_signed(&ct_res);

// println!("clear = {}, dec_res = {}", clear, dec_res);
assert_eq!(clear, dec_res);
// println!("clear = {}, dec_res = {}", clear, dec_res);
assert_eq!(clear, dec_res);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,16 +565,16 @@ where

let mut rng = rand::thread_rng();

let modulus = unsigned_modulus(cks.parameters().message_modulus(), NB_CTXT as u32);

executor.setup(&cks, sks.clone());

for _ in 0..nb_tests_smaller {
for num_blocks in 1..MAX_NB_CTXT {
let modulus = unsigned_modulus(cks.parameters().message_modulus(), num_blocks as u32);

let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;

let ctxt_0 = cks.encrypt(clear_0);
let ctxt_1 = cks.encrypt(clear_1);
let ctxt_0 = cks.as_ref().encrypt_radix(clear_0, num_blocks);
let ctxt_1 = cks.as_ref().encrypt_radix(clear_1, num_blocks);

let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
Expand Down Expand Up @@ -642,6 +642,7 @@ where
}

// Test with trivial inputs
let modulus = unsigned_modulus(cks.parameters().message_modulus(), NB_CTXT as u32);
for _ in 0..4 {
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,33 +311,35 @@ where

let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

executor.setup(&cks, sks);

for _ in 0..nb_tests_smaller {
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;
for num_blocks in 1..MAX_NB_CTXT {
// message_modulus^vec_length
let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32) as u64;

let ctxt_1 = cks.encrypt(clear1);
let ctxt_2 = cks.encrypt(clear2);
for _ in 0..nb_tests_smaller {
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

let mut res = ctxt_1.clone();
let mut clear = clear1;
let ctxt_1 = cks.as_ref().encrypt_radix(clear1, num_blocks);
let ctxt_2 = cks.as_ref().encrypt_radix(clear2, num_blocks);

// Subtract multiple times to raise the degree
for _ in 0..nb_tests_smaller {
let tmp = executor.execute((&res, &ctxt_2));
res = executor.execute((&res, &ctxt_2));
assert!(res.block_carries_are_empty());
assert_eq!(res, tmp);
let mut res = ctxt_1.clone();
let mut clear = clear1;

// Subtract multiple times to raise the degree
for _ in 0..nb_tests_smaller {
let tmp = executor.execute((&res, &ctxt_2));
res = executor.execute((&res, &ctxt_2));
assert!(res.block_carries_are_empty());
assert_eq!(res, tmp);

panic_if_any_block_is_not_clean(&res, &cks);
panic_if_any_block_is_not_clean(&res, &cks);

clear = (clear.wrapping_sub(clear2)) % modulus;
let dec: u64 = cks.decrypt(&res);
assert_eq!(clear, dec);
clear = (clear.wrapping_sub(clear2)) % modulus;
let dec: u64 = cks.decrypt(&res);
assert_eq!(clear, dec);
}
}
}
}
Expand Down

0 comments on commit 7209633

Please sign in to comment.