Skip to content

Commit

Permalink
MVPoly: util to compute nested loops indices
Browse files Browse the repository at this point in the history
It will be used later in the cross-term computation to perform the multi-nomial
sum
  • Loading branch information
dannywillems committed Aug 31, 2024
1 parent 3f3897b commit d11c543
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 2 deletions.
51 changes: 51 additions & 0 deletions mvpoly/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,54 @@ pub fn compute_all_two_factors_decomposition(
factors
}
}

/// Compute the list of indices to perform N nested loops of different size each.
/// In other words, if we have to perform the 3 nested loops:
/// ```rust
/// let n1 = 3;
/// let n2 = 3;

Check warning on line 227 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L223-L227

Added lines #L223 - L227 were not covered by tests
/// let n3 = 5;
/// for i in 0..n1 {
/// for j in 0..n2 {

Check warning on line 230 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L229-L230

Added lines #L229 - L230 were not covered by tests
/// for k in 0..n3 {
/// }
/// }
/// }
/// ```
/// the output will be all the possible values of `i`, `j`, and `k`.
/// The algorithm is as follows:
/// ```rust

Check warning on line 238 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L233-L238

Added lines #L233 - L238 were not covered by tests
/// let n1 = 3;
/// let n2 = 3;
/// let n3 = 5;

Check warning on line 241 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L240-L241

Added lines #L240 - L241 were not covered by tests
/// (0..(n1 * n2 * n3)).map(|l| {
/// let i = l % n1;
/// let j = (l / n1) % n2;
/// let k = (l / (n1 * n2)) % n3;
/// (i, j, k)
/// });
/// ```
/// For N nested loops, the algorithm is the same, with the division increasing
/// by the factor `N_k` for the index `i_(k + 1)`

Check warning on line 250 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L244-L250

Added lines #L244 - L250 were not covered by tests
pub fn compute_indices_nested_loop(nested_loop_sizes: Vec<usize>) -> Vec<Vec<usize>> {
let n = nested_loop_sizes.iter().product();
(0..n)
.map(|i| {
let mut div = 1;
// Compute indices for the loop, step i
let indices: Vec<usize> = nested_loop_sizes
.iter()
.map(|n_i| {
let k = (i / div) % n_i;
div *= n_i;
k
})
.collect();
assert!(
div == n,
"The division must be equal to the number of terms at the end"

Check warning on line 267 in mvpoly/src/utils.rs

View check run for this annotation

Codecov / codecov/patch

mvpoly/src/utils.rs#L267

Added line #L267 was not covered by tests
);
indices
})
.collect()
}
75 changes: 73 additions & 2 deletions mvpoly/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;

use mvpoly::utils::{
compute_all_two_factors_decomposition, get_mapping_with_primes, is_prime, naive_prime_factors,
PrimeNumberGenerator,
compute_all_two_factors_decomposition, compute_indices_nested_loop, get_mapping_with_primes,
is_prime, naive_prime_factors, PrimeNumberGenerator,
};

pub const FIRST_FIFTY_PRIMES: [usize; 50] = [
Expand Down Expand Up @@ -225,3 +225,74 @@ pub fn test_compute_all_two_factors_decomposition_with_multiplicity() {
.to_vec()
);
}

#[test]
pub fn test_compute_indices_nested_loop() {
let nested_loops = vec![2, 2];
// sorting to get the same order
let mut exp_indices = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

let nested_loops = vec![3, 2];
// sorting to get the same order
let mut exp_indices = vec![
vec![0, 0],
vec![0, 1],
vec![1, 0],
vec![1, 1],
vec![2, 0],
vec![2, 1],
];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);

let nested_loops = vec![3, 3, 2, 2];
// sorting to get the same order
let mut exp_indices = vec![
vec![0, 0, 0, 0],
vec![0, 0, 0, 1],
vec![0, 0, 1, 0],
vec![0, 0, 1, 1],
vec![0, 1, 0, 0],
vec![0, 1, 0, 1],
vec![0, 1, 1, 0],
vec![0, 1, 1, 1],
vec![0, 2, 0, 0],
vec![0, 2, 0, 1],
vec![0, 2, 1, 0],
vec![0, 2, 1, 1],
vec![1, 0, 0, 0],
vec![1, 0, 0, 1],
vec![1, 0, 1, 0],
vec![1, 0, 1, 1],
vec![1, 1, 0, 0],
vec![1, 1, 0, 1],
vec![1, 1, 1, 0],
vec![1, 1, 1, 1],
vec![1, 2, 0, 0],
vec![1, 2, 0, 1],
vec![1, 2, 1, 0],
vec![1, 2, 1, 1],
vec![2, 0, 0, 0],
vec![2, 0, 0, 1],
vec![2, 0, 1, 0],
vec![2, 0, 1, 1],
vec![2, 1, 0, 0],
vec![2, 1, 0, 1],
vec![2, 1, 1, 0],
vec![2, 1, 1, 1],
vec![2, 2, 0, 0],
vec![2, 2, 0, 1],
vec![2, 2, 1, 0],
vec![2, 2, 1, 1],
];
exp_indices.sort();
let mut comp_indices = compute_indices_nested_loop(nested_loops);
comp_indices.sort();
assert_eq!(exp_indices, comp_indices);
}

0 comments on commit d11c543

Please sign in to comment.