Skip to content

Commit

Permalink
refactor(core): rename LweBskGroupingFactor MultiBitGroupingFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed May 13, 2024
1 parent 3c3b4ca commit 7746bb4
Show file tree
Hide file tree
Showing 44 changed files with 444 additions and 439 deletions.
6 changes: 3 additions & 3 deletions tfhe/benches/core_crypto/dev_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount,
GlweDimension,
PolynomialSize,
LweBskGroupingFactor,
MultiBitGroupingFactor,
ThreadCount,
) {
if Scalar::BITS == 64 {
Expand All @@ -35,7 +35,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount(5),
GlweDimension(1),
PolynomialSize(1024),
LweBskGroupingFactor(2),
MultiBitGroupingFactor(2),
ThreadCount(5),
)
} else if Scalar::BITS == 32 {
Expand All @@ -46,7 +46,7 @@ fn get_bench_params<Scalar: Numeric>() -> (
DecompositionLevelCount(1),
GlweDimension(3),
PolynomialSize(512),
LweBskGroupingFactor(2),
MultiBitGroupingFactor(2),
ThreadCount(5),
)
} else {
Expand Down
7 changes: 5 additions & 2 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ fn throughput_benchmark_parameters<Scalar: UnsignedInteger>(
}
}

fn multi_bit_benchmark_parameters<Scalar: UnsignedInteger + Default>(
) -> Vec<(String, CryptoParametersRecord<Scalar>, LweBskGroupingFactor)> {
fn multi_bit_benchmark_parameters<Scalar: UnsignedInteger + Default>() -> Vec<(
String,
CryptoParametersRecord<Scalar>,
MultiBitGroupingFactor,
)> {
if Scalar::BITS == 64 {
let parameters = if cfg!(feature = "gpu") {
vec![
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/c_api/core_crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ pub unsafe extern "C" fn core_crypto_lwe_multi_bit_bootstrapping_key_element_siz
let output_glwe_sk_poly_size = PolynomialSize(output_glwe_sk_poly_size);

let lwe_multi_bit_level_count = DecompositionLevelCount(lwe_multi_bit_level_count);
let lwe_multi_bit_grouping_factor = LweBskGroupingFactor(lwe_multi_bit_grouping_factor);
let lwe_multi_bit_grouping_factor = MultiBitGroupingFactor(lwe_multi_bit_grouping_factor);

*result = lwe_multi_bit_bootstrap_key_size(
input_lwe_sk_dim,
Expand Down Expand Up @@ -472,7 +472,7 @@ pub unsafe extern "C" fn core_crypto_par_generate_lwe_multi_bit_bootstrapping_ke

let lwe_multi_bit_base_log = DecompositionBaseLog(lwe_multi_bit_base_log);
let lwe_multi_bit_level_count = DecompositionLevelCount(lwe_multi_bit_level_count);
let lwe_multi_bit_grouping_factor = LweBskGroupingFactor(lwe_multi_bit_grouping_factor);
let lwe_multi_bit_grouping_factor = MultiBitGroupingFactor(lwe_multi_bit_grouping_factor);

let lwe_multi_bit_slice_len = {
let bsk = LweMultiBitBootstrapKeyOwned::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use rayon::prelude::*;
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let grouping_factor = LweBskGroupingFactor(2);
/// let grouping_factor = MultiBitGroupingFactor(2);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
Expand Down Expand Up @@ -192,7 +192,7 @@ pub fn allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
Expand Down Expand Up @@ -242,7 +242,7 @@ where
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let grouping_factor = LweBskGroupingFactor(2);
/// let grouping_factor = MultiBitGroupingFactor(2);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
Expand Down Expand Up @@ -453,7 +453,7 @@ pub fn par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
Expand Down Expand Up @@ -638,7 +638,7 @@ pub fn allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
Expand Down Expand Up @@ -799,7 +799,7 @@ pub fn par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub use super::lwe_programmable_bootstrapping::generate_programmable_bootstrap_g

pub(crate) fn modulus_switch_multi_bit<Scalar>(
ciphertext_modulus_log: CiphertextModulusLog,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
lwe_mask_elements: &[Scalar],
) -> impl Iterator<Item = usize> + '_
where
Expand All @@ -46,7 +46,7 @@ where
// Returns an iterator of booleans (as usize), corresponding to successive mask group elements
// to indicate if they must be used at the given power_set_index
pub(crate) fn selection_bit(
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
power_set_index: usize,
) -> impl Iterator<Item = usize> {
debug_assert!(power_set_index < grouping_factor.multi_bit_power_set_size().0);
Expand All @@ -72,7 +72,7 @@ pub struct StandardMultiBitModulusSwitchedCt<
C: Container<Element = Scalar> + Sync,
> {
pub input: &'a LweCiphertext<C>,
pub grouping_factor: LweBskGroupingFactor,
pub grouping_factor: MultiBitGroupingFactor,
pub log_modulus: CiphertextModulusLog,
}

Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn prepare_multi_bit_ggsw_mem_optimized<GgswBufferCont, GgswGroupCont, Fouri
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
/// let grouping_factor = MultiBitGroupingFactor(2); // Group bits in pairs
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
Expand Down Expand Up @@ -843,7 +843,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
/// Gaussian::from_dispersion_parameter(StandardDev(0.00000000000000029403601535432533), 0.0);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
/// let grouping_factor = MultiBitGroupingFactor(2); // Group bits in pairs
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::core_crypto::commons::math::random::{
};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension,
LweBskGroupingFactor, LweDimension, PolynomialSize,
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension,
MultiBitGroupingFactor, PolynomialSize,
};
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
use crate::core_crypto::entities::*;
Expand Down Expand Up @@ -36,7 +36,7 @@ fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence<
let base_log = DecompositionBaseLog(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let grouping_factor = LweBskGroupingFactor(
let grouping_factor = MultiBitGroupingFactor(
crate::core_crypto::commons::test_tools::random_usize_between(2..4),
);
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
Expand Down
12 changes: 6 additions & 6 deletions tfhe/src/core_crypto/algorithms/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ pub const MULTI_BIT_2_2_2_PARAMS: MultiBitTestParams<u64> = MultiBitTestParams {
)),
message_modulus_log: MessageModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(2),
grouping_factor: MultiBitGroupingFactor(2),
thread_count: ThreadCount(5),
};

Expand All @@ -210,7 +210,7 @@ pub const MULTI_BIT_3_3_2_PARAMS: MultiBitTestParams<u64> = MultiBitTestParams {
)),
message_modulus_log: MessageModulusLog(6),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(2),
grouping_factor: MultiBitGroupingFactor(2),
thread_count: ThreadCount(5),
};

Expand All @@ -228,7 +228,7 @@ pub const MULTI_BIT_2_2_2_CUSTOM_MOD_PARAMS: MultiBitTestParams<u64> = MultiBitT
)),
message_modulus_log: MessageModulusLog(3),
ciphertext_modulus: CiphertextModulus::new(1 << 63),
grouping_factor: LweBskGroupingFactor(2),
grouping_factor: MultiBitGroupingFactor(2),
thread_count: ThreadCount(5),
};

Expand All @@ -246,7 +246,7 @@ pub const MULTI_BIT_2_2_3_PARAMS: MultiBitTestParams<u64> = MultiBitTestParams {
)),
message_modulus_log: MessageModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
grouping_factor: MultiBitGroupingFactor(3),
thread_count: ThreadCount(12),
};

Expand All @@ -264,7 +264,7 @@ pub const MULTI_BIT_3_3_3_PARAMS: MultiBitTestParams<u64> = MultiBitTestParams {
)),
message_modulus_log: MessageModulusLog(6),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
grouping_factor: MultiBitGroupingFactor(3),
thread_count: ThreadCount(5),
};

Expand All @@ -282,7 +282,7 @@ pub const MULTI_BIT_2_2_3_CUSTOM_MOD_PARAMS: MultiBitTestParams<u64> = MultiBitT
)),
message_modulus_log: MessageModulusLog(3),
ciphertext_modulus: CiphertextModulus::new(1 << 63),
grouping_factor: LweBskGroupingFactor(3),
grouping_factor: MultiBitGroupingFactor(3),
thread_count: ThreadCount(12),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn assert_ms_compression<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usiz
fn assert_ms_multi_bit_compression<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>>(
ct: &LweCiphertext<Vec<Scalar>>,
log_modulus: CiphertextModulusLog,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) {
let a = StandardMultiBitModulusSwitchedCt {
input: ct,
Expand Down Expand Up @@ -124,7 +124,7 @@ fn test_ms_with_packing() {

assert_ms_compression(&lwe_ciphertext_in, log_modulus);

for grouping_factor in (1..6).map(LweBskGroupingFactor) {
for grouping_factor in (1..6).map(MultiBitGroupingFactor) {
if lwe_dimension.0 % grouping_factor.0 == 0 {
assert_ms_multi_bit_compression(
&lwe_ciphertext_in,
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/algorithms/test/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub struct MultiBitTestParams<Scalar: UnsignedInteger> {
pub glwe_noise_distribution: DynamicDistribution<Scalar>,
pub message_modulus_log: MessageModulusLog,
pub ciphertext_modulus: CiphertextModulus<Scalar>,
pub grouping_factor: LweBskGroupingFactor,
pub grouping_factor: MultiBitGroupingFactor,
pub thread_count: ThreadCount,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::core_crypto::commons::math::random::{
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweDimension,
GlweSize, LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweMaskCount, LweSize,
GlweSize, LweCiphertextCount, LweDimension, LweMaskCount, LweSize, MultiBitGroupingFactor,
PolynomialSize,
};
use concrete_csprng::generators::ForkError;
Expand Down Expand Up @@ -80,7 +80,7 @@ impl<G: ByteRandomGenerator> MaskRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let mask_bytes = mask_elements_per_multi_bit_bsk_ggsw_group(
level,
Expand All @@ -99,7 +99,7 @@ impl<G: ByteRandomGenerator> MaskRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let mask_bytes = mask_elements_per_ggsw(level, glwe_size, polynomial_size)
Expand Down Expand Up @@ -233,7 +233,7 @@ impl<G: ParallelByteRandomGenerator> MaskRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let mask_bytes = mask_elements_per_multi_bit_bsk_ggsw_group(
level,
Expand All @@ -252,7 +252,7 @@ impl<G: ParallelByteRandomGenerator> MaskRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let mask_bytes = mask_elements_per_ggsw(level, glwe_size, polynomial_size)
Expand Down Expand Up @@ -416,7 +416,7 @@ fn mask_elements_per_multi_bit_bsk_ggsw_group(
level: DecompositionLevelCount,
glwe_size: GlweSize,
poly_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> MaskElementCount {
MaskElementCount(
grouping_factor.multi_bit_power_set_size().0
Expand Down
11 changes: 6 additions & 5 deletions tfhe/src/core_crypto/commons/generators/encryption/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::core_crypto::commons::math::random::{
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweSize,
LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweMaskCount, LweSize, PolynomialSize,
LweCiphertextCount, LweDimension, LweMaskCount, LweSize, MultiBitGroupingFactor,
PolynomialSize,
};
use concrete_csprng::generators::ForkError;
use mask_random_generator::MaskRandomGenerator;
Expand Down Expand Up @@ -67,7 +68,7 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let mask_iter = self.mask.fork_multi_bit_bsk_to_ggsw_group::<T>(
lwe_dimension,
Expand All @@ -93,7 +94,7 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let mask_iter = self.mask.fork_multi_bit_bsk_ggsw_group_to_ggsw::<T>(
level,
Expand Down Expand Up @@ -378,7 +379,7 @@ impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let mask_iter = self.mask.par_fork_multi_bit_bsk_to_ggsw_group::<T>(
lwe_dimension,
Expand All @@ -404,7 +405,7 @@ impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
grouping_factor: MultiBitGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let mask_iter = self.mask.par_fork_multi_bit_bsk_ggsw_group_to_ggsw::<T>(
level,
Expand Down
Loading

0 comments on commit 7746bb4

Please sign in to comment.