Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mz/rename multi bit #1124

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

template <typename Torus, class params>
__device__ Torus calculates_monomial_degree(Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t power_set_index,
uint32_t grouping_factor) {
Torus x = 0;
for (int i = 0; i < grouping_factor; i++) {
uint32_t mask_position = grouping_factor - (i + 1);
int selection_bit = (ggsw_idx >> mask_position) & 1;
int selection_bit = (power_set_index >> mask_position) & 1;
x += selection_bit * lwe_array_group[i];
}

Expand Down
6 changes: 3 additions & 3 deletions tfhe/benches/core_crypto/dev_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount,
GlweDimension,
PolynomialSize,
LweBskGroupingFactor,
MultiBitGroupingFactor,
ThreadCount,
) {
if Scalar::BITS == 64 {
Expand All @@ -35,7 +35,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount(5),
GlweDimension(1),
PolynomialSize(1024),
LweBskGroupingFactor(2),
MultiBitGroupingFactor(2),
ThreadCount(5),
)
} else if Scalar::BITS == 32 {
Expand All @@ -46,7 +46,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount(1),
GlweDimension(3),
PolynomialSize(512),
LweBskGroupingFactor(2),
MultiBitGroupingFactor(2),
ThreadCount(5),
)
} else {
Expand Down
7 changes: 5 additions & 2 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ fn throughput_benchmark_parameters<Scalar: UnsignedInteger>(
}
}

fn multi_bit_benchmark_parameters<Scalar: UnsignedInteger + Default>(
) -> Vec<(String, CryptoParametersRecord<Scalar>, LweBskGroupingFactor)> {
fn multi_bit_benchmark_parameters<Scalar: UnsignedInteger + Default>() -> Vec<(
String,
CryptoParametersRecord<Scalar>,
MultiBitGroupingFactor,
)> {
if Scalar::BITS == 64 {
let parameters = if cfg!(feature = "gpu") {
vec![
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/c_api/core_crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ pub unsafe extern "C" fn core_crypto_lwe_multi_bit_bootstrapping_key_element_siz
let output_glwe_sk_poly_size = PolynomialSize(output_glwe_sk_poly_size);

let lwe_multi_bit_level_count = DecompositionLevelCount(lwe_multi_bit_level_count);
let lwe_multi_bit_grouping_factor = LweBskGroupingFactor(lwe_multi_bit_grouping_factor);
let lwe_multi_bit_grouping_factor = MultiBitGroupingFactor(lwe_multi_bit_grouping_factor);

*result = lwe_multi_bit_bootstrap_key_size(
input_lwe_sk_dim,
Expand Down Expand Up @@ -472,7 +472,7 @@ pub unsafe extern "C" fn core_crypto_par_generate_lwe_multi_bit_bootstrapping_ke

let lwe_multi_bit_base_log = DecompositionBaseLog(lwe_multi_bit_base_log);
let lwe_multi_bit_level_count = DecompositionLevelCount(lwe_multi_bit_level_count);
let lwe_multi_bit_grouping_factor = LweBskGroupingFactor(lwe_multi_bit_grouping_factor);
let lwe_multi_bit_grouping_factor = MultiBitGroupingFactor(lwe_multi_bit_grouping_factor);

let lwe_multi_bit_slice_len = {
let bsk = LweMultiBitBootstrapKeyOwned::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use rayon::prelude::*;
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let grouping_factor = LweBskGroupingFactor(2);
/// let grouping_factor = MultiBitGroupingFactor(2);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
Expand Down Expand Up @@ -64,9 +64,9 @@ use rayon::prelude::*;
/// &mut encryption_generator,
/// );
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
/// let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(multi_bit_power_set_size.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
Expand Down Expand Up @@ -137,10 +137,10 @@ pub fn generate_lwe_multi_bit_bootstrap_key<
.unwrap();

let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -182,7 +182,7 @@ pub fn allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
Expand Down Expand Up @@ -232,7 +232,7 @@ where
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let grouping_factor = LweBskGroupingFactor(2);
/// let grouping_factor = MultiBitGroupingFactor(2);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
Expand Down Expand Up @@ -282,9 +282,9 @@ where
///
/// par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
/// let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(multi_bit_power_set_size.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
Expand Down Expand Up @@ -354,10 +354,10 @@ pub fn par_generate_lwe_multi_bit_bootstrap_key<
.unwrap();

let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

output
.par_chunks_exact_mut(ggsw_per_multi_bit_element.0)
.par_chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -436,7 +436,7 @@ pub fn par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
Expand Down Expand Up @@ -559,10 +559,10 @@ pub fn generate_seeded_lwe_multi_bit_bootstrap_key<
.unwrap();

let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -612,7 +612,7 @@ pub fn allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
Expand Down Expand Up @@ -708,10 +708,10 @@ pub fn par_generate_seeded_lwe_multi_bit_bootstrap_key<
.unwrap();

let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

output
.par_chunks_exact_mut(ggsw_per_multi_bit_element.0)
.par_chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -766,7 +766,7 @@ pub fn par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ pub use super::lwe_programmable_bootstrapping::generate_programmable_bootstrap_g

pub fn modulus_switch_multi_bit<Scalar>(
ciphertext_modulus_log: CiphertextModulusLog,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
lwe_mask_elements: &[Scalar],
) -> impl Iterator<Item = usize> + '_
where
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
{
// Start at 1, the first ggsw is not rotated
(1..grouping_factor.ggsw_per_multi_bit_element().0).map(move |power_set_index| {
(1..grouping_factor.multi_bit_power_set_size().0).map(move |power_set_index| {
let mut monomial_degree = Scalar::ZERO;
for (&mask_element, selection_bit) in lwe_mask_elements
.iter()
Expand All @@ -46,10 +46,10 @@ where
// Returns an iterator of booleans (as usize), corresponding to successive mask group elements
// to indicate if they must be used at the given power_set_index
pub(crate) fn selection_bit(
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
power_set_index: usize,
) -> impl Iterator<Item = usize> {
debug_assert!(power_set_index < grouping_factor.ggsw_per_multi_bit_element().0);
debug_assert!(power_set_index < grouping_factor.multi_bit_power_set_size().0);

(0..grouping_factor.0).map(move |mask_idx| {
let mask_position = grouping_factor.0 - (mask_idx + 1);
Expand All @@ -72,7 +72,7 @@ pub struct StandardMultiBitModulusSwitchedCt<
C: Container<Element = Scalar> + Sync,
> {
pub input: &'a LweCiphertext<C>,
pub grouping_factor: LweBskGroupingFactor,
pub grouping_factor: MultiBitGroupingFactor,
pub log_modulus: CiphertextModulusLog,
}

Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn prepare_multi_bit_ggsw_mem_optimized<GgswBufferCont, GgswGroupCont, Fouri
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
/// let grouping_factor = MultiBitGroupingFactor(2); // Group bits in pairs
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
Expand Down Expand Up @@ -417,7 +417,7 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -473,8 +473,8 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -648,7 +648,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -699,8 +699,8 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -843,7 +843,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
/// let grouping_factor = MultiBitGroupingFactor(2); // Group bits in pairs
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
Expand Down Expand Up @@ -1236,7 +1236,7 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
let ggsw_vec: Vec<_> = multi_bit_bsk.iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -1312,8 +1312,8 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -1502,7 +1502,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let ggsw_vec: Vec<_> = multi_bit_bsk.iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -1573,8 +1573,8 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ pub fn decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator<

// Forking logic must match multi bit BSK generation
let output_grouping_factor = output_bsk.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

let forking_config = input_bsk.decompression_fork_config(Uniform);

let gen_iter = generator.try_fork_from_config(forking_config).unwrap();

for ((mut output_ggsw_group, input_ggsw_group), mut loop_generator) in output_bsk
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(input_bsk.chunks_exact(ggsw_per_multi_bit_element.0))
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(input_bsk.chunks_exact(multi_bit_power_set_size.0))
.zip(gen_iter)
{
let group_forking_config = input_ggsw_group.decompression_fork_config(Uniform);
Expand Down Expand Up @@ -121,15 +121,15 @@ pub fn par_decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator

// Forking logic must match multi bit BSK generation
let output_grouping_factor = output_bsk.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

let forking_config = input_bsk.decompression_fork_config(Uniform);

let gen_iter = generator.par_try_fork_from_config(forking_config).unwrap();

output_bsk
.par_chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(input_bsk.par_chunks_exact(ggsw_per_multi_bit_element.0))
.par_chunks_exact_mut(multi_bit_power_set_size.0)
.zip(input_bsk.par_chunks_exact(multi_bit_power_set_size.0))
.zip(gen_iter)
.for_each(
|((mut output_ggsw_group, input_ggsw_group), mut loop_generator)| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::core_crypto::commons::math::random::{
};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension,
LweBskGroupingFactor, LweDimension, PolynomialSize,
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension,
MultiBitGroupingFactor, PolynomialSize,
};
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
use crate::core_crypto::entities::*;
Expand Down Expand Up @@ -36,7 +36,7 @@ fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence<
let base_log = DecompositionBaseLog(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let grouping_factor = LweBskGroupingFactor(
let grouping_factor = MultiBitGroupingFactor(
crate::core_crypto::commons::test_tools::random_usize_between(2..4),
);
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
Expand Down
Loading
Loading