From dc8d41cea0b25ac1f3355d2838fd59c956abfb2e Mon Sep 17 00:00:00 2001 From: Danny Willems Date: Thu, 29 Aug 2024 18:44:42 -0700 Subject: [PATCH] MVPoly: util to compute nested loops indices It will be used later in the cross-term computation to perform the multi-nomial sum --- mvpoly/src/utils.rs | 47 +++++++++++++++++++++++++++ mvpoly/tests/utils.rs | 75 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/mvpoly/src/utils.rs b/mvpoly/src/utils.rs index 098ec6baac..0cd104d60d 100644 --- a/mvpoly/src/utils.rs +++ b/mvpoly/src/utils.rs @@ -219,3 +219,50 @@ pub fn compute_all_two_factors_decomposition( factors } } + +/// Compute the list of indices to perform N nested loops of different size each. +/// In other words, if we have to perform the 3 nested loops: +/// ```rust +/// let n1 = 3; +/// let n2 = 3; +/// let n3 = 5; +/// for i in 0..n1 { +/// for j in 0..n2 { +/// for k in 0..n3 { +/// } +/// } +/// } +/// ``` +/// the output will be all the possible values of `i`, `j`, and `k`. +/// The algorithm is as follows: +/// ```rust +/// for l in 0..(n1 * n2 * n3) { +/// i = l % n1; +/// j = (l / n1) % n2; +/// k = (l / (n1 * n2)) % n3; +/// } +/// ``` +/// For N nested loops, the algorithm is the same, with the division increasing +/// by the factor `N_k` for the index `i_(k + 1)` +pub fn compute_indices_nested_loop(nested_loop_sizes: Vec) -> Vec> { + let n = nested_loop_sizes.iter().product(); + (0..n) + .map(|i| { + let mut div = 1; + // Compute indices for the loop, step i + let indices: Vec = nested_loop_sizes + .iter() + .map(|n_i| { + let k = (i / div) % n_i; + div *= n_i; + k + }) + .collect(); + assert!( + div == n, + "The division must be equal to the number of terms at the end" + ); + indices + }) + .collect() +} diff --git a/mvpoly/tests/utils.rs b/mvpoly/tests/utils.rs index 76f0475cdb..c4d98d117d 100644 --- a/mvpoly/tests/utils.rs +++ b/mvpoly/tests/utils.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use mvpoly::utils::{ - compute_all_two_factors_decomposition, get_mapping_with_primes, is_prime, naive_prime_factors, - PrimeNumberGenerator, + compute_all_two_factors_decomposition, compute_indices_nested_loop, get_mapping_with_primes, + is_prime, naive_prime_factors, PrimeNumberGenerator, }; pub const FIRST_FIFTY_PRIMES: [usize; 50] = [ @@ -225,3 +225,74 @@ pub fn test_compute_all_two_factors_decomposition_with_multiplicity() { .to_vec() ); } + +#[test] +pub fn test_compute_indices_nested_loop() { + let nested_loops = vec![2, 2]; + // sorting to get the same order + let mut exp_indices = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]]; + exp_indices.sort(); + let mut comp_indices = compute_indices_nested_loop(nested_loops); + comp_indices.sort(); + assert_eq!(exp_indices, comp_indices); + + let nested_loops = vec![3, 2]; + // sorting to get the same order + let mut exp_indices = vec![ + vec![0, 0], + vec![0, 1], + vec![1, 0], + vec![1, 1], + vec![2, 0], + vec![2, 1], + ]; + exp_indices.sort(); + let mut comp_indices = compute_indices_nested_loop(nested_loops); + comp_indices.sort(); + assert_eq!(exp_indices, comp_indices); + + let nested_loops = vec![3, 3, 2, 2]; + // sorting to get the same order + let mut exp_indices = vec![ + vec![0, 0, 0, 0], + vec![0, 0, 0, 1], + vec![0, 0, 1, 0], + vec![0, 0, 1, 1], + vec![0, 1, 0, 0], + vec![0, 1, 0, 1], + vec![0, 1, 1, 0], + vec![0, 1, 1, 1], + vec![0, 2, 0, 0], + vec![0, 2, 0, 1], + vec![0, 2, 1, 0], + vec![0, 2, 1, 1], + vec![1, 0, 0, 0], + vec![1, 0, 0, 1], + vec![1, 0, 1, 0], + vec![1, 0, 1, 1], + vec![1, 1, 0, 0], + vec![1, 1, 0, 1], + vec![1, 1, 1, 0], + vec![1, 1, 1, 1], + vec![1, 2, 0, 0], + vec![1, 2, 0, 1], + vec![1, 2, 1, 0], + vec![1, 2, 1, 1], + vec![2, 0, 0, 0], + vec![2, 0, 0, 1], + vec![2, 0, 1, 0], + vec![2, 0, 1, 1], + vec![2, 1, 0, 0], + vec![2, 1, 0, 1], + vec![2, 1, 1, 0], + vec![2, 1, 1, 1], + vec![2, 2, 0, 0], + vec![2, 2, 0, 1], + vec![2, 2, 1, 0], + vec![2, 2, 1, 1], + ]; + exp_indices.sort(); + let mut comp_indices = compute_indices_nested_loop(nested_loops); + comp_indices.sort(); + assert_eq!(exp_indices, comp_indices); +}