diff --git a/mvpoly/src/prime.rs b/mvpoly/src/prime.rs index fc47280fbb..b775ef3efc 100644 --- a/mvpoly/src/prime.rs +++ b/mvpoly/src/prime.rs @@ -161,7 +161,8 @@ use rand::RngCore; use std::ops::{Index, IndexMut}; use crate::utils::{ - compute_all_two_factors_decomposition, naive_prime_factors, PrimeNumberGenerator, + compute_all_two_factors_decomposition, compute_indices_nested_loop, naive_prime_factors, + PrimeNumberGenerator, }; /// Represents a multivariate polynomial of degree less than `D` in `N` variables. @@ -439,6 +440,126 @@ impl Dense { }); result } + + /// Compute the cross-terms as described in [Behind Nova: cross-terms + /// computation for high degree + /// gates](https://hackmd.io/@dannywillems/Syo5MBq90) + /// + /// The polynomial must not necessarily be homogeneous. For this reason, the + /// values `u1` and `u2` represents the extra variable that is used to make + /// the polynomial homogeneous. + /// + /// The homogeneous degree is supposed to be the one defined by the type of + /// the polynomial, i.e. `D`. + /// + /// The output is a map of `D - 1` values that represents the cross-terms + /// for each power of `r`. + // IMPROVEME: Dummy implementation, a cache can be used to save the + // previously computed multiplications and powers. + // Maybe using a symbolic form could speed up the computation. + pub fn compute_cross_terms( + &self, + eval1: &[F; N], + eval2: &[F; N], + u1: F, + u2: F, + ) -> HashMap { + assert!( + D >= 2, + "The degree of the polynomial must be greater than 2" + ); + let mut cross_terms_by_powers_of_r: HashMap = HashMap::new(); + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + // Computing invididual contribution of each monomial + // FIXME: handle constant, prime decompo returns empty list of 1. + self.coeff.iter().enumerate().for_each(|(i, c)| { + // If the coefficient is zero, we skip the computation as the contribution + // is null. + if *c == F::zero() || i == 0 { + // FIXME: handle constant, prime decompo returns empty list of 1. + } else { + let idx = self.normalized_indices[i]; + let prime_decomposition = naive_prime_factors(idx, &mut prime_gen); + // Fetch individual degree + let degrees: Vec = prime_decomposition.iter().map(|(_, d)| *d).collect(); + // No cross-terms + // FIXME + if degrees.len() == 1 && (degrees[0] == 0 || degrees[0] == D) { + return; + } + // We compute the missing degree to homogeneize the polynomial. + // The variable u given as a parameter will be used to + // homogeneize the polynomial. + let u_degree: usize = D - degrees.iter().sum::(); + // Will be used to compute the nested sums + // It returns all the indices i_1, ..., i_k for the sums: + // Σ_{i_1 = 0}^{n_1} Σ_{i_2 = 0}^{n_2} ... Σ_{i_k = 0}^{n_k} + let indices = compute_indices_nested_loop(degrees.iter().map(|d| *d + 1).collect()); + // We treat the homogeneisation degree separately in the sum. It + // eases the code below + for i in 0..=u_degree { + // Add the binomial from the homogeneisation + // i.e (u_degree choose i) + let u_binomial_term = binomial(u_degree, i); + indices.iter().for_each(|indices| { + let sum_indices = indices.iter().sum::() + i; + // If the sum of the indices is 0 or D, we skip the + // computation as the contribution would go in the + // evaluation of the polynomial at each evaluation + // vectors eval1 and eval2 + if sum_indices == 0 || sum_indices == D { + return; + } + // Compute + // (n_1 choose i_1) * (n_2 choose i_2) * ... * (n_k choose i_k) + let binomial_term = indices + .iter() + .zip(degrees.iter()) + .fold(u_binomial_term, |acc, (i, d)| acc * binomial(*d, *i)); + let binomial_term = F::from(binomial_term as u64); + // Compute the product x_k^i_k + let monomial_eval1 = prime_decomposition.iter().zip(indices.iter()).fold( + F::one(), + |acc, ((p, _), i)| { + // Get the evaluation of the corresponding + // variable + let inv_p = primes + .iter() + .position(|x| x == p) + .expect("The variable must be in the list of primes"); + acc * eval1[inv_p].pow([*i as u64]) + }, + ); + // Compute the product x'_k^(n_k - i_k) + let monomial_eval2 = prime_decomposition.iter().zip(indices.iter()).fold( + F::one(), + |acc, ((p, d), i)| { + // Get the evaluation of the corresponding + // variable + let inv_p = primes + .iter() + .position(|x| x == p) + .expect("The variable must be in the list of primes"); + acc * eval2[inv_p].pow([(d - *i) as u64]) + }, + ); + // u1^i * u2^(u_degree - i) + let u = u1.pow([i as u64]) * u2.pow([(u_degree - i) as u64]); + let res = binomial_term * monomial_eval1 * monomial_eval2 * u; + let res = *c * res; + // power of r is Σ (n_k - i_k) + let power_r: usize = D - sum_indices; + cross_terms_by_powers_of_r + .entry(power_r) + .and_modify(|e| *e += res) + .or_insert(res); + }); + } + } + }); + cross_terms_by_powers_of_r + } } impl Default for Dense { diff --git a/mvpoly/tests/prime.rs b/mvpoly/tests/prime.rs index 59fd4834c7..03dab22f90 100644 --- a/mvpoly/tests/prime.rs +++ b/mvpoly/tests/prime.rs @@ -802,3 +802,30 @@ fn test_mvpoly_mul_pbt() { let p2 = unsafe { Dense::::random(&mut rng, Some(max_degree)) }; assert_eq!(p1.clone() * p2.clone(), p2.clone() * p1.clone()); } + +#[test] +fn test_mvpoly_compute_cross_terms_eval_zero() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Dense::::random(&mut rng, None) }; + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let eval: [Fp; 4] = [Fp::zero(); 4]; + let res = p1.compute_cross_terms(&eval, &eval, u1, u2); + res.iter() + .for_each(|(_power, cross_term)| assert_eq!(*cross_term, Fp::zero())); +} + +// FIXME +#[test] +fn test_mvpoly_compute_cross_terms_degree_two() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Dense::::random(&mut rng, None) }; + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let res = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + // We only have one cross-term in this case + assert_eq!(res.len(), 1); + println!("{:?}", res); +}