diff --git a/fuzz/no_gpu_fuzz/Cargo.toml b/fuzz/no_gpu_fuzz/Cargo.toml index c745b1fb3..49d11a4d7 100644 --- a/fuzz/no_gpu_fuzz/Cargo.toml +++ b/fuzz/no_gpu_fuzz/Cargo.toml @@ -17,12 +17,21 @@ ibig = "0.3.6" p3-goldilocks = { git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } p3-field = { git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } +p3-mersenne-31 = { git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } +p3-field = { git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } + [[bin]] name = "field_fuzzer" path = "fuzz_targets/field_fuzzer.rs" test = false doc = false +[[bin]] +name = "field_fuzz_mersenne31" +path = "fuzz_targets/field_mersenne31.rs" +test = false +doc = false + [[bin]] name = "field_mini_goldilocks" path = "fuzz_targets/field_mini_goldilocks.rs" diff --git a/fuzz/no_gpu_fuzz/fuzz_targets/field_mersenne31.rs b/fuzz/no_gpu_fuzz/fuzz_targets/field_mersenne31.rs new file mode 100644 index 000000000..1a6de4d60 --- /dev/null +++ b/fuzz/no_gpu_fuzz/fuzz_targets/field_mersenne31.rs @@ -0,0 +1,90 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use lambdaworks_math::field::{ + element::FieldElement, + fields::{ + mersenne31::field::{Mersenne31Field, MERSENNE_31_PRIME_FIELD_ORDER}, + } +}; +use p3_mersenne_31::Mersenne31; +use p3_field::{Field, PrimeField32, PrimeField64, AbstractField}; + +fuzz_target!(|values: (u32, u32)| { + // Note: we filter values outside of order as it triggers an assert within plonky3 disallowing values n >= Self::Order + if values.0 >= MERSENNE_31_PRIME_FIELD_ORDER || values.1 >= MERSENNE_31_PRIME_FIELD_ORDER { + return + } + + let (value_u32_a, value_u32_b) = values; + + let a = FieldElement::::from(value_u32_a as u64); + let b = FieldElement::::from(value_u32_b as u64); + + // Note: if we parse using from_canonical_u32 fails due to check that n < Self::Order + let a_expected = Mersenne31::from_canonical_u32(value_u32_a); + let b_expected = Mersenne31::from_canonical_u32(value_u32_b); + + let add_u32 = &a + &b; + let addition = a_expected + b_expected; + + assert_eq!(add_u32.representative(), addition.as_canonical_u32()); + + let sub_u32 = &a - &b; + let substraction = a_expected - b_expected; + assert_eq!(sub_u32.representative(), substraction.as_canonical_u32()); + + let mul_u32 = &a * &b; + let multiplication = a_expected * b_expected; + assert_eq!(mul_u32.representative(), multiplication.as_canonical_u32()); + + let pow = &a.pow(b.representative()); + let expected_pow = a_expected.exp_u64(b_expected.as_canonical_u64()); + assert_eq!(pow.representative(), expected_pow.as_canonical_u32()); + + if value_u32_b != 0 && b.inv().is_ok() && b_expected.try_inverse().is_some() { + let div = &a / &b; + assert_eq!(&div * &b, a.clone()); + let expected_div = a_expected / b_expected; + assert_eq!(div.representative(), expected_div.as_canonical_u32()); + } + + for n in [&a, &b] { + match n.sqrt() { + Some((fst_sqrt, snd_sqrt)) => { + assert_eq!(fst_sqrt.square(), snd_sqrt.square(), "Squared roots don't match each other"); + assert_eq!(n, &fst_sqrt.square(), "Squared roots don't match original number"); + } + None => {} + }; + } + + // Axioms soundness + + let one = FieldElement::::one(); + let zero = FieldElement::::zero(); + + assert_eq!(&a + &zero, a, "Neutral add element a failed"); + assert_eq!(&b + &zero, b, "Neutral mul element b failed"); + assert_eq!(&a * &one, a, "Neutral add element a failed"); + assert_eq!(&b * &one, b, "Neutral mul element b failed"); + + assert_eq!(&a + &b, &b + &a, "Commutative add property failed"); + assert_eq!(&a * &b, &b * &a, "Commutative mul property failed"); + + let c = &a * &b; + assert_eq!((&a + &b) + &c, &a + (&b + &c), "Associative add property failed"); + assert_eq!((&a * &b) * &c, &a * (&b * &c), "Associative mul property failed"); + + assert_eq!(&a * (&b + &c), &a * &b + &a * &c, "Distributive property failed"); + + assert_eq!(&a - &a, zero, "Inverse add a failed"); + assert_eq!(&b - &b, zero, "Inverse add b failed"); + + if a != zero { + assert_eq!(&a * a.inv().unwrap(), one, "Inverse mul a failed"); + } + if b != zero { + assert_eq!(&b * b.inv().unwrap(), one, "Inverse mul b failed"); + } +}); diff --git a/math/benches/criterion_field.rs b/math/benches/criterion_field.rs index 402d01dc9..1c21822de 100644 --- a/math/benches/criterion_field.rs +++ b/math/benches/criterion_field.rs @@ -2,6 +2,8 @@ use criterion::{criterion_group, criterion_main, Criterion}; use pprof::criterion::{Output, PProfProfiler}; mod fields; +use fields::mersenne31::mersenne31_ops_benchmarks; +use fields::mersenne31_montgomery::mersenne31_mont_ops_benchmarks; use fields::{ stark252::starkfield_ops_benchmarks, u64_goldilocks::u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery::u64_goldilocks_montgomery_ops_benchmarks, @@ -10,6 +12,6 @@ use fields::{ criterion_group!( name = field_benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = starkfield_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks + targets = starkfield_ops_benchmarks, mersenne31_ops_benchmarks, mersenne31_mont_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks ); criterion_main!(field_benches); diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs new file mode 100644 index 000000000..99e3921a5 --- /dev/null +++ b/math/benches/fields/mersenne31.rs @@ -0,0 +1,195 @@ +use std::hint::black_box; + +use criterion::Criterion; +use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use rand::random; + +pub type F = FieldElement; + +#[inline(never)] +#[no_mangle] +#[export_name = "util::rand_mersenne31_field_elements"] +pub fn rand_field_elements(num: usize) -> Vec<(F, F)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push((F::new(random()), F::new(random()))); + } + result +} + +pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1, 10, 100, 1000, 10000, 100000, 1000000] + .into_iter() + .map(rand_field_elements) + .collect::>(); + let mut group = c.benchmark_group("Mersenne31 operations"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) + black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(1_u64)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(2_u64)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x) * black_box(x)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input( + format!("pow {:?}", &i.len()), + &(i, 5u64), + |bench, (i, a)| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(*a)); + } + }); + }, + ); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) - black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) / black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) == black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { + let i: Vec = i.iter().map(|(x, _)| x * x).collect(); + bench.iter(|| { + for x in &i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { + // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) & black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) | black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(u32, u32)> = i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) ^ black_box(*y)); + } + }); + }); + } +} diff --git a/math/benches/fields/mersenne31_montgomery.rs b/math/benches/fields/mersenne31_montgomery.rs new file mode 100644 index 000000000..a3298a0d1 --- /dev/null +++ b/math/benches/fields/mersenne31_montgomery.rs @@ -0,0 +1,231 @@ +use std::hint::black_box; + +use criterion::Criterion; +use lambdaworks_math::{ + field::{ + element::FieldElement, + fields::{ + fft_friendly::u64_mersenne_montgomery_field::{ + Mersenne31MontgomeryPrimeField, MontgomeryConfigMersenne31PrimeField, + }, + montgomery_backed_prime_fields::IsModulus, + }, + }, + unsigned_integer::{ + element::{UnsignedInteger, U64}, + montgomery::MontgomeryAlgorithms, + }, +}; +use rand::random; + +pub type F = FieldElement; +const NUM_LIMBS: usize = 1; + +#[inline(never)] +#[no_mangle] +#[export_name = "util::rand_mersenne31_mont_field_elements"] +pub fn rand_field_elements(num: usize) -> Vec<(F, F)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + let rand_a = UnsignedInteger { limbs: random() }; + let rand_b = UnsignedInteger { limbs: random() }; + result.push((F::new(rand_a), F::new(rand_b))); + } + result +} + +pub fn mersenne31_mont_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1, 10, 100, 1000, 10000, 100000, 1000000] + .into_iter() + .map(rand_field_elements) + .collect::>(); + let mut group = c.benchmark_group("Mersenne31 Mont operations"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("add {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) + black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("pow by 1 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(1_u64)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with pow {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(2_u64)); + } + }); + }); + } + + // The non-boxed constants are intentional as they are + // normally computed at compile time. + for i in input.clone().into_iter() { + group.bench_with_input(format!("sos_square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + MontgomeryAlgorithms::sos_square( + black_box(black_box(x.value())), + &>::MODULUS, + &Mersenne31MontgomeryPrimeField::MU, + ); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("square with mul {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x) * black_box(x)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input( + format!("pow {:?}", &i.len()), + &(i, 5u64), + |bench, (i, a)| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).pow(*a)); + } + }); + }, + ); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sub {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) - black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("inv {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("div {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) / black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("eq {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) == black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("sqrt squared {:?}", &i.len()), &i, |bench, i| { + let i: Vec = i.iter().map(|(x, _)| x * x).collect(); + bench.iter(|| { + for x in &i { + black_box(black_box(x).sqrt()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitand {:?}", &i.len()), &i, |bench, i| { + // Note: we should strive to have the number of limbs be generic... ideally this benchmark group itself should have a generic type that we call into from the main runner. + let i: Vec<(UnsignedInteger, UnsignedInteger)> = + i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) & black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(UnsignedInteger, UnsignedInteger)> = + i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) | black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("bitxor {:?}", &i.len()), &i, |bench, i| { + let i: Vec<(UnsignedInteger, UnsignedInteger)> = + i.iter().map(|(x, y)| (*x.value(), *y.value())).collect(); + bench.iter(|| { + for (x, y) in &i { + black_box(black_box(*x) ^ black_box(*y)); + } + }); + }); + } +} diff --git a/math/benches/fields/mod.rs b/math/benches/fields/mod.rs index 5a545ddcf..a28773c6c 100644 --- a/math/benches/fields/mod.rs +++ b/math/benches/fields/mod.rs @@ -1,3 +1,5 @@ +pub mod mersenne31; +pub mod mersenne31_montgomery; pub mod stark252; pub mod u64_goldilocks; pub mod u64_goldilocks_montgomery; diff --git a/math/src/field/fields/fft_friendly/mod.rs b/math/src/field/fields/fft_friendly/mod.rs index d5c27cb92..92b6a5286 100644 --- a/math/src/field/fields/fft_friendly/mod.rs +++ b/math/src/field/fields/fft_friendly/mod.rs @@ -4,3 +4,5 @@ pub mod babybear; pub mod stark_252_prime_field; /// Implemenation of the Goldilocks Prime Field p = 2^64 - 2^32 + 1 pub mod u64_goldilocks; +/// Implemenation of the Mersenne Prime field p = 2^31 - 1 +pub mod u64_mersenne_montgomery_field; diff --git a/math/src/field/fields/fft_friendly/stark_252_prime_field.rs b/math/src/field/fields/fft_friendly/stark_252_prime_field.rs index 7f0fb6c43..737a21014 100644 --- a/math/src/field/fields/fft_friendly/stark_252_prime_field.rs +++ b/math/src/field/fields/fft_friendly/stark_252_prime_field.rs @@ -83,7 +83,7 @@ impl FieldElement { } } -#[allow(clippy::incorrect_partial_ord_impl_on_ord_type)] +#[allow(clippy::non_canonical_partial_ord_impl)] impl PartialOrd for FieldElement { fn partial_cmp(&self, other: &Self) -> Option { self.representative().partial_cmp(&other.representative()) diff --git a/math/src/field/fields/fft_friendly/u64_goldilocks.rs b/math/src/field/fields/fft_friendly/u64_goldilocks.rs index 09e31fc89..7b35e4405 100644 --- a/math/src/field/fields/fft_friendly/u64_goldilocks.rs +++ b/math/src/field/fields/fft_friendly/u64_goldilocks.rs @@ -30,7 +30,7 @@ impl FieldElement { } } -#[allow(clippy::incorrect_partial_ord_impl_on_ord_type)] +#[allow(clippy::non_canonical_partial_ord_impl)] impl PartialOrd for FieldElement { fn partial_cmp(&self, other: &Self) -> Option { self.representative().partial_cmp(&other.representative()) diff --git a/math/src/field/fields/fft_friendly/u64_mersenne_montgomery_field.rs b/math/src/field/fields/fft_friendly/u64_mersenne_montgomery_field.rs new file mode 100644 index 000000000..a7521e4a6 --- /dev/null +++ b/math/src/field/fields/fft_friendly/u64_mersenne_montgomery_field.rs @@ -0,0 +1,44 @@ +use crate::{ + field::{ + element::FieldElement, + fields::montgomery_backed_prime_fields::{IsModulus, MontgomeryBackendPrimeField}, + }, + unsigned_integer::element::U64, +}; + +pub type U64MontgomeryBackendPrimeField = MontgomeryBackendPrimeField; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct MontgomeryConfigMersenne31PrimeField; +impl IsModulus for MontgomeryConfigMersenne31PrimeField { + //Mersenne Prime p = 2^31 - 1 + const MODULUS: U64 = U64::from_u64(2147483647); +} + +pub type Mersenne31MontgomeryPrimeField = + U64MontgomeryBackendPrimeField; + +impl FieldElement { + pub fn to_bytes_le(&self) -> [u8; 8] { + let limbs = self.representative().limbs; + limbs[0].to_le_bytes() + } + + pub fn to_bytes_be(&self) -> [u8; 8] { + let limbs = self.representative().limbs; + limbs[0].to_be_bytes() + } +} + +#[allow(clippy::non_canonical_partial_ord_impl)] +impl PartialOrd for FieldElement { + fn partial_cmp(&self, other: &Self) -> Option { + self.representative().partial_cmp(&other.representative()) + } +} + +impl Ord for FieldElement { + fn cmp(&self, other: &Self) -> core::cmp::Ordering { + self.representative().cmp(&other.representative()) + } +} diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs new file mode 100644 index 000000000..e8773e8d5 --- /dev/null +++ b/math/src/field/fields/mersenne31/extension.rs @@ -0,0 +1,306 @@ +use crate::field::{ + element::FieldElement, + errors::FieldError, + extensions::{ + cubic::{CubicExtensionField, HasCubicNonResidue}, + quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, + }, + traits::IsField, +}; + +use super::field::Mersenne31Field; + +//Note: The inverse calculation in mersenne31/plonky3 differs from the default quadratic extension so I implemented the complex extension. +////////////////// +#[derive(Clone, Debug)] +pub struct Mersenne31Complex; + +impl IsField for Mersenne31Complex { + //Elements represents a[0] = real, a[1] = imaginary + type BaseType = [FieldElement; 2]; + + /// Returns the component wise addition of `a` and `b` + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [a[0] + b[0], a[1] + b[1]] + } + + //NOTE: THIS uses Gauss algorithm. Bench this against plonky 3 implementation to see what is faster. + /// Returns the multiplication of `a` and `b` using the following + /// equation: + /// (a0 + a1 * t) * (b0 + b1 * t) = a0 * b0 + a1 * b1 * Self::residue() + (a0 * b1 + a1 * b0) * t + /// where `t.pow(2)` equals `Q::residue()`. + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + let a0b0 = a[0] * b[0]; + let a1b1 = a[1] * b[1]; + let z = (a[0] + a[1]) * (b[0] + b[1]); + [a0b0 - a1b1, z - a0b0 - a1b1] + } + + fn square(a: &Self::BaseType) -> Self::BaseType { + let [a0, a1] = a; + let v0 = a0 * a1; + let c0 = (a0 + a1) * (a0 - a1); + let c1 = v0 + v0; + [c0, c1] + } + /// Returns the component wise subtraction of `a` and `b` + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [a[0] - b[0], a[1] - b[1]] + } + + /// Returns the component wise negation of `a` + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-a[0], -a[1]] + } + + /// Returns the multiplicative inverse of `a` + fn inv(a: &Self::BaseType) -> Result { + let inv_norm = (a[0].pow(2_u64) + a[1].pow(2_u64)).inv()?; + Ok([a[0] * inv_norm, -a[1] * inv_norm]) + } + + /// Returns the division of `a` and `b` + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + Self::mul(a, &Self::inv(b).unwrap()) + } + + /// Returns a boolean indicating whether `a` and `b` are equal component wise. + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a[0] == b[0] && a[1] == b[1] + } + + /// Returns the additive neutral element of the field extension. + fn zero() -> Self::BaseType { + [FieldElement::zero(), FieldElement::zero()] + } + + /// Returns the multiplicative neutral element of the field extension. + fn one() -> Self::BaseType { + [FieldElement::one(), FieldElement::zero()] + } + + /// Returns the element `x * 1` where 1 is the multiplicative neutral element. + fn from_u64(x: u64) -> Self::BaseType { + [FieldElement::from(x), FieldElement::zero()] + } + + /// Takes as input an element of BaseType and returns the internal representation + /// of that element in the field. + /// Note: for this case this is simply the identity, because the components + /// already have correct representations. + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } +} + +pub type Mersenne31ComplexQuadraticExtensionField = QuadraticExtensionField; + +//TODO: Check this should be for complex and not base field +impl HasQuadraticNonResidue for Mersenne31Complex { + type BaseField = Mersenne31Complex; + + // Verifiable in Sage with + // ```sage + // p = 2**31 - 1 # Mersenne31 + // F = GF(p) # The base field GF(p) + // R. = F[] # The polynomial ring over F + // K. = F.extension(x^2 + 1) # The complex extension field + // R2. = K[] + // f2 = y^2 - i - 2 + // assert f2.is_irreducible() + // ``` + fn residue() -> FieldElement { + FieldElement::from(&Mersenne31Complex::from_base_type([ + FieldElement::::from(2), + FieldElement::::one(), + ])) + } +} + +pub type Mersenne31ComplexCubicExtensionField = CubicExtensionField; + +impl HasCubicNonResidue for Mersenne31Complex { + type BaseField = Mersenne31Complex; + + // Verifiable in Sage with + // ```sage + // p = 2**31 - 1 # Mersenne31 + // F = GF(p) # The base field GF(p) + // R. = F[] # The polynomial ring over F + // K. = F.extension(x^2 + 1) # The complex extension field + // R2. = K[] + // f2 = y^3 - 5*i + // assert f2.is_irreducible() + // ``` + fn residue() -> FieldElement { + FieldElement::from(&Mersenne31Complex::from_base_type([ + FieldElement::::zero(), + FieldElement::::from(5), + ])) + } +} + +#[cfg(test)] +mod tests { + use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; + + use super::*; + + type Fi = Mersenne31Complex; + type F = FieldElement; + + //NOTE: from_u64 reflects from_real + //NOTE: for imag use from_base_type + + #[test] + fn add_real_one_plus_one_is_two() { + assert_eq!(Fi::add(&Fi::one(), &Fi::one()), Fi::from_u64(2)) + } + + #[test] + fn add_real_neg_one_plus_one_is_zero() { + assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::one()), Fi::zero()) + } + + #[test] + fn add_real_neg_one_plus_two_is_one() { + assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::from_u64(2)), Fi::one()) + } + + #[test] + fn add_real_neg_one_plus_neg_one_is_order_sub_two() { + assert_eq!( + Fi::add(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), + Fi::from_u64((MERSENNE_31_PRIME_FIELD_ORDER - 2).into()) + ) + } + + #[test] + fn add_complex_one_plus_one_two() { + //Manually declare the complex part to one + let one = Fi::from_base_type([F::zero(), F::one()]); + let two = Fi::from_base_type([F::zero(), F::from(2)]); + assert_eq!(Fi::add(&one, &one), two) + } + + #[test] + fn add_complex_neg_one_plus_one_is_zero() { + //Manually declare the complex part to one + let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + let one = Fi::from_base_type([F::zero(), F::one()]); + assert_eq!(Fi::add(&neg_one, &one), Fi::zero()) + } + + #[test] + fn add_complex_neg_one_plus_two_is_one() { + let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + let two = Fi::from_base_type([F::zero(), F::from(2)]); + let one = Fi::from_base_type([F::zero(), F::one()]); + assert_eq!(Fi::add(&neg_one, &two), one) + } + + #[test] + fn add_complex_neg_one_plus_neg_one_imag_is_order_sub_two() { + let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + assert_eq!( + Fi::add(&neg_one, &neg_one)[1], + F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2) + ) + } + + #[test] + fn add_order() { + let a = Fi::from_base_type([-F::one(), F::one()]); + let b = Fi::from_base_type([F::from(2), F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2)]); + let c = Fi::from_base_type([F::one(), -F::one()]); + assert_eq!(Fi::add(&a, &b), c) + } + + #[test] + fn add_equal_zero() { + let a = Fi::from_base_type([-F::one(), -F::one()]); + let b = Fi::from_base_type([F::one(), F::one()]); + assert_eq!(Fi::add(&a, &b), Fi::zero()) + } + + #[test] + fn add_plus_one() { + let a = Fi::from_base_type([F::one(), F::from(2)]); + let b = Fi::from_base_type([F::one(), F::one()]); + let c = Fi::from_base_type([F::from(2), F::from(3)]); + assert_eq!(Fi::add(&a, &b), c) + } + + #[test] + fn sub_real_one_sub_one_is_zero() { + assert_eq!(Fi::sub(&Fi::one(), &Fi::one()), Fi::zero()) + } + + #[test] + fn sub_real_two_sub_two_is_zero() { + assert_eq!( + Fi::sub(&Fi::from_u64(2u64), &Fi::from_u64(2u64)), + Fi::zero() + ) + } + + #[test] + fn sub_real_neg_one_sub_neg_one_is_zero() { + assert_eq!( + Fi::sub(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), + Fi::zero() + ) + } + + #[test] + fn sub_real_two_sub_one_is_one() { + assert_eq!(Fi::sub(&Fi::from_u64(2), &Fi::one()), Fi::one()) + } + + #[test] + fn sub_real_neg_one_sub_zero_is_neg_one() { + assert_eq!( + Fi::sub(&Fi::neg(&Fi::one()), &Fi::zero()), + Fi::neg(&Fi::one()) + ) + } + + #[test] + fn sub_complex_one_sub_one_is_zero() { + let one = Fi::from_base_type([F::zero(), F::one()]); + assert_eq!(Fi::sub(&one, &one), Fi::zero()) + } + + #[test] + fn sub_complex_two_sub_two_is_zero() { + let two = Fi::from_base_type([F::zero(), F::from(2)]); + assert_eq!(Fi::sub(&two, &two), Fi::zero()) + } + + #[test] + fn sub_complex_neg_one_sub_neg_one_is_zero() { + let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + assert_eq!(Fi::sub(&neg_one, &neg_one), Fi::zero()) + } + + #[test] + fn sub_complex_two_sub_one_is_one() { + let two = Fi::from_base_type([F::zero(), F::from(2)]); + let one = Fi::from_base_type([F::zero(), F::one()]); + assert_eq!(Fi::sub(&two, &one), one) + } + + #[test] + fn sub_complex_neg_one_sub_zero_is_neg_one() { + let neg_one = Fi::from_base_type([F::zero(), -F::one()]); + assert_eq!(Fi::sub(&neg_one, &Fi::zero()), neg_one) + } + + #[test] + fn mul() { + let a = Fi::from_base_type([F::from(2), F::from(2)]); + let b = Fi::from_base_type([F::from(4), F::from(5)]); + let c = Fi::from_base_type([-F::from(2), F::from(18)]); + assert_eq!(Fi::mul(&a, &b), c) + } +} diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs new file mode 100644 index 000000000..7a07597f9 --- /dev/null +++ b/math/src/field/fields/mersenne31/field.rs @@ -0,0 +1,402 @@ +use crate::{ + errors::CreationError, + field::{ + element::FieldElement, + errors::FieldError, + traits::{IsField, IsPrimeField}, + }, +}; +use core::fmt::{self, Display}; + +/// Represents a 31 bit integer value +/// Invariants: +/// 31st bit is clear +/// n < MODULUS +#[derive(Debug, Clone, Copy, Hash, PartialOrd, Ord, PartialEq, Eq)] +pub struct Mersenne31Field; + +impl Mersenne31Field { + fn weak_reduce(n: u32) -> u32 { + // To reduce 'n' to 31 bits we clear its MSB, then add it back in its reduced form. + let msb = n & (1 << 31); + let msb_reduced = msb >> 31; + let res = msb ^ n; + + // assert msb_reduced fits within 31 bits + debug_assert!((res >> 31) == 0 && (msb_reduced >> 1) == 0); + res + msb_reduced + } + + fn as_representative(n: &u32) -> u32 { + if *n == MERSENNE_31_PRIME_FIELD_ORDER { + 0 + } else { + *n + } + } +} + +pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; + +//NOTE: This implementation was inspired by and borrows from the work done by the Plonky3 team +// https://github.com/Plonky3/Plonky3/blob/main/mersenne-31/src/lib.rs +// Thank you for pushing this technology forward. +impl IsField for Mersenne31Field { + type BaseType = u32; + + /// Returns the sum of `a` and `b`. + fn add(a: &u32, b: &u32) -> u32 { + // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 + // Working with i32 means we get a flag which informs us if overflow happens + let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); + let sum_u32 = sum_i32 as u32; + let sum_corr = sum_u32.wrapping_sub(MERSENNE_31_PRIME_FIELD_ORDER); + + //assert 31 bit clear + // If self + rhs did not overflow, return it. + // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1). + let sum = if over { sum_corr } else { sum_u32 }; + debug_assert!((sum >> 31) == 0); + Self::as_representative(&sum) + } + + /// Returns the multiplication of `a` and `b`. + // Note: for powers of 2 we can perform bit shifting this would involve overriding the trait implementation + fn mul(a: &u32, b: &u32) -> u32 { + let prod = u64::from(*a) * u64::from(*b); + let prod_lo = (prod as u32) & ((1 << 31) - 1); + let prod_hi = (prod >> 31) as u32; + //assert prod_hi and prod_lo 31 bit clear + debug_assert!((prod_lo >> 31) == 0 && (prod_hi >> 31) == 0); + Self::add(&prod_lo, &prod_hi) + } + + fn sub(a: &u32, b: &u32) -> u32 { + let (mut sub, over) = a.overflowing_sub(*b); + + // If we didn't overflow we have the correct value. + // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1. + // Hence we need to remove the most significant bit and subtract 1. + sub -= over as u32; + sub & MERSENNE_31_PRIME_FIELD_ORDER + } + + /// Returns the additive inverse of `a`. + fn neg(a: &u32) -> u32 { + // NOTE: MODULUS known to have 31 bit clear + MERSENNE_31_PRIME_FIELD_ORDER - a + } + + /// Returns the multiplicative inverse of `a`. + fn inv(a: &u32) -> Result { + if *a == Self::zero() || *a == MERSENNE_31_PRIME_FIELD_ORDER { + return Err(FieldError::InvZeroError); + } + let p101 = Self::mul(&Self::pow(a, 4u32), a); + let p1111 = Self::mul(&Self::square(&p101), &p101); + let p11111111 = Self::mul(&Self::pow(&p1111, 16u32), &p1111); + let p111111110000 = Self::pow(&p11111111, 16u32); + let p111111111111 = Self::mul(&p111111110000, &p1111); + let p1111111111111111 = Self::mul(&Self::pow(&p111111110000, 16u32), &p11111111); + let p1111111111111111111111111111 = + Self::mul(&Self::pow(&p1111111111111111, 4096u32), &p111111111111); + let p1111111111111111111111111111101 = + Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); + Ok(p1111111111111111111111111111101) + } + + /// Returns the division of `a` and `b`. + fn div(a: &u32, b: &u32) -> u32 { + let b_inv = Self::inv(b).expect("InvZeroError"); + Self::mul(a, &b_inv) + } + + /// Returns a boolean indicating whether `a` and `b` are equal or not. + fn eq(a: &u32, b: &u32) -> bool { + Self::as_representative(a) == Self::representative(b) + } + + /// Returns the additive neutral element. + fn zero() -> Self::BaseType { + 0u32 + } + + /// Returns the multiplicative neutral element. + fn one() -> u32 { + 1u32 + } + + /// Returns the element `x * 1` where 1 is the multiplicative neutral element. + fn from_u64(x: u64) -> u32 { + let (lo, hi) = (x as u32 as u64, x >> 32); + // 2^32 = 2 (mod Mersenne 31 bit prime) + // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 + let t = lo + 2 * hi; + + const MASK: u64 = (1 << 31) - 1; + let (lo, hi) = ((t & MASK) as u32, (t >> 31) as u32); + // 2^31 = 1 mod Mersenne31 + // lo < 2^31, hi < 6, so lo + hi < 2^32. + Self::weak_reduce(lo + hi) + } + + /// Takes as input an element of BaseType and returns the internal representation + /// of that element in the field. + fn from_base_type(x: u32) -> u32 { + Self::weak_reduce(x) + } +} + +impl IsPrimeField for Mersenne31Field { + type RepresentativeType = u32; + + // Since our invariant guarantees that `value` fits in 31 bits, there is only one possible value + // `value` that is not canonical, namely 2^31 - 1 = p = 0. + fn representative(x: &u32) -> u32 { + debug_assert!((x >> 31) == 0); + Self::as_representative(x) + } + + fn field_bit_size() -> usize { + ((MERSENNE_31_PRIME_FIELD_ORDER - 1).ilog2() + 1) as usize + } + + fn from_hex(hex_string: &str) -> Result { + let mut hex_string = hex_string; + // Remove 0x if it's on the string + let mut char_iterator = hex_string.chars(); + if hex_string.len() > 2 + && char_iterator.next().unwrap() == '0' + && char_iterator.next().unwrap() == 'x' + { + hex_string = &hex_string[2..]; + } + u32::from_str_radix(hex_string, 16).map_err(|_| CreationError::InvalidHexString) + } +} + +impl FieldElement { + #[cfg(feature = "std")] + pub fn to_bytes_le(&self) -> Vec { + self.representative().to_le_bytes().to_vec() + } + + #[cfg(feature = "std")] + pub fn to_bytes_be(&self) -> Vec { + self.representative().to_be_bytes().to_vec() + } +} + +impl Display for FieldElement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:x}", self.representative())?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + type F = Mersenne31Field; + + #[test] + fn from_hex_for_b_is_11() { + assert_eq!(F::from_hex("B").unwrap(), 11); + } + + #[test] + fn from_hex_for_0x1_a_is_26() { + assert_eq!(F::from_hex("0x1a").unwrap(), 26); + } + + #[test] + fn bit_size_of_field_is_31() { + assert_eq!( + ::field_bit_size(), + 31 + ); + } + + #[test] + fn one_plus_1_is_2() { + let a = F::one(); + let b = F::one(); + let c = F::add(&a, &b); + assert_eq!(c, 2u32); + } + + #[test] + fn neg_1_plus_1_is_0() { + let a = F::neg(&F::one()); + let b = F::one(); + let c = F::add(&a, &b); + assert_eq!(c, F::zero()); + } + + #[test] + fn neg_1_plus_2_is_1() { + let a = F::neg(&F::one()); + let b = F::from_base_type(2u32); + let c = F::add(&a, &b); + assert_eq!(c, F::one()); + } + + #[test] + fn max_order_plus_1_is_0() { + let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); + let b = F::one(); + let c = F::add(&a, &b); + assert_eq!(c, F::zero()); + } + + #[test] + fn comparing_13_and_13_are_equal() { + let a = F::from_base_type(13); + let b = F::from_base_type(13); + assert_eq!(a, b); + } + + #[test] + fn comparing_13_and_8_they_are_not_equal() { + let a = F::from_base_type(13); + let b = F::from_base_type(8); + assert_ne!(a, b); + } + + #[test] + fn one_sub_1_is_0() { + let a = F::one(); + let b = F::one(); + let c = F::sub(&a, &b); + assert_eq!(c, F::zero()); + } + + #[test] + fn zero_sub_1_is_order_minus_1() { + let a = F::zero(); + let b = F::one(); + let c = F::sub(&a, &b); + assert_eq!(c, MERSENNE_31_PRIME_FIELD_ORDER - 1); + } + + #[test] + fn neg_1_sub_neg_1_is_0() { + let a = F::neg(&F::one()); + let b = F::neg(&F::one()); + let c = F::sub(&a, &b); + assert_eq!(c, F::zero()); + } + + #[test] + fn neg_1_sub_1_is_neg_1() { + let a = F::neg(&F::one()); + let b = F::zero(); + let c = F::sub(&a, &b); + assert_eq!(c, F::neg(&F::one())); + } + + #[test] + fn mul_neutral_element() { + let a = F::from_base_type(1); + let b = F::from_base_type(2); + let c = F::mul(&a, &b); + assert_eq!(c, F::from_base_type(2)); + } + + #[test] + fn mul_2_3_is_6() { + let a = F::from_base_type(2); + let b = F::from_base_type(3); + assert_eq!(a * b, F::from_base_type(6)); + } + + #[test] + fn mul_order_neg_1() { + let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); + let b = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); + let c = F::mul(&a, &b); + assert_eq!(c, F::from_base_type(1)); + } + + #[test] + fn pow_p_neg_1() { + assert_eq!( + F::pow(&F::from_base_type(2), MERSENNE_31_PRIME_FIELD_ORDER - 1), + F::one() + ) + } + + #[test] + fn inv_0_error() { + let result = F::inv(&F::zero()); + assert!(matches!(result, Err(FieldError::InvZeroError))); + } + + #[test] + fn inv_2() { + let result = F::inv(&F::from_base_type(2u32)).unwrap(); + // sage: 1 / F(2) = 1073741824 + assert_eq!(result, 1073741824); + } + + #[test] + fn pow_2_3() { + assert_eq!(F::pow(&F::from_base_type(2), 3_u64), 8) + } + + #[test] + fn div_1() { + assert_eq!(F::div(&F::from_base_type(2), &F::from_base_type(1)), 2) + } + + #[test] + fn div_4_2() { + assert_eq!(F::div(&F::from_base_type(4), &F::from_base_type(2)), 2) + } + + // 1431655766 + #[test] + fn div_4_3() { + // sage: F(4) / F(3) = 1431655766 + assert_eq!( + F::div(&F::from_base_type(4), &F::from_base_type(3)), + 1431655766 + ) + } + + #[test] + fn two_plus_its_additive_inv_is_0() { + let two = F::from_base_type(2); + + assert_eq!(F::add(&two, &F::neg(&two)), F::zero()) + } + + #[test] + fn from_u64_test() { + let num = F::from_u64(1u64); + assert_eq!(num, F::one()); + } + + #[test] + fn creating_a_field_element_from_its_representative_returns_the_same_element_1() { + let change = 1; + let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); + let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + assert_eq!(f1, f2); + } + + #[test] + fn creating_a_field_element_from_its_representative_returns_the_same_element_2() { + let change = 8; + let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); + let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + assert_eq!(f1, f2); + } + + #[test] + fn from_base_type_test() { + let b = F::from_base_type(1u32); + assert_eq!(b, F::one()); + } +} diff --git a/math/src/field/fields/mersenne31/mod.rs b/math/src/field/fields/mersenne31/mod.rs new file mode 100644 index 000000000..4bfd3daf1 --- /dev/null +++ b/math/src/field/fields/mersenne31/mod.rs @@ -0,0 +1,2 @@ +pub mod extension; +pub mod field; diff --git a/math/src/field/fields/mod.rs b/math/src/field/fields/mod.rs index c857a6880..33c39d816 100644 --- a/math/src/field/fields/mod.rs +++ b/math/src/field/fields/mod.rs @@ -1,5 +1,7 @@ /// Implementation of two-adic prime fields to use with the Fast Fourier Transform (FFT). pub mod fft_friendly; +/// Implementation of the 32-bit Mersenne Prime field (p = 2^31 - 1) +pub mod mersenne31; pub mod montgomery_backed_prime_fields; /// Implementation of the Goldilocks Prime field (p = 2^448 - 2^224 - 1) pub mod p448_goldilocks_prime_field; diff --git a/math/src/unsigned_integer/element.rs b/math/src/unsigned_integer/element.rs index fab80ce6f..70e111854 100644 --- a/math/src/unsigned_integer/element.rs +++ b/math/src/unsigned_integer/element.rs @@ -36,7 +36,7 @@ pub struct UnsignedInteger { // NOTE: manually implementing `PartialOrd` may seem unorthodox, but the // derived implementation had terrible performance. -#[allow(clippy::incorrect_partial_ord_impl_on_ord_type)] +#[allow(clippy::non_canonical_partial_ord_impl)] impl PartialOrd for UnsignedInteger { fn partial_cmp(&self, other: &Self) -> Option { let mut i = 0;