diff --git a/ec/src/hashing/map_to_curve_hasher.rs b/ec/src/hashing/map_to_curve_hasher.rs index fa50a5e57..e0f46f323 100644 --- a/ec/src/hashing/map_to_curve_hasher.rs +++ b/ec/src/hashing/map_to_curve_hasher.rs @@ -56,7 +56,7 @@ where // 5. P = clear_cofactor(R) // 6. return P - let rand_field_elems = self.field_hasher.hash_to_field(msg, 2); + let rand_field_elems = self.field_hasher.hash_to_field::<2>(msg); let rand_curve_elem_0 = M2C::map_to_curve(rand_field_elems[0])?; let rand_curve_elem_1 = M2C::map_to_curve(rand_field_elems[1])?; diff --git a/ff/Cargo.toml b/ff/Cargo.toml index b201b086b..bfbf07798 100644 --- a/ff/Cargo.toml +++ b/ff/Cargo.toml @@ -19,6 +19,7 @@ ark-ff-asm.workspace = true ark-ff-macros.workspace = true ark-std.workspace = true ark-serialize.workspace = true +arrayvec = { version = "0.7", default-features = false } derivative = { workspace = true, features = ["use_core"] } num-traits.workspace = true paste.workspace = true diff --git a/ff/src/fields/field_hashers/expander/mod.rs b/ff/src/fields/field_hashers/expander/mod.rs index 8b1ef0a12..16cc17df5 100644 --- a/ff/src/fields/field_hashers/expander/mod.rs +++ b/ff/src/fields/field_hashers/expander/mod.rs @@ -1,99 +1,119 @@ // The below implementation is a rework of https://github.com/armfazh/h2c-rust-ref // With some optimisations +use core::marker::PhantomData; + use ark_std::vec::Vec; -use digest::{DynDigest, ExtendableOutput, Update}; + +use arrayvec::ArrayVec; +use digest::{ExtendableOutput, FixedOutputReset, Update}; + pub trait Expander { - fn construct_dst_prime(&self) -> Vec; fn expand(&self, msg: &[u8], length: usize) -> Vec; } const MAX_DST_LENGTH: usize = 255; -const LONG_DST_PREFIX: [u8; 17] = [ - //'H', '2', 'C', '-', 'O', 'V', 'E', 'R', 'S', 'I', 'Z', 'E', '-', 'D', 'S', 'T', '-', - 0x48, 0x32, 0x43, 0x2d, 0x4f, 0x56, 0x45, 0x52, 0x53, 0x49, 0x5a, 0x45, 0x2d, 0x44, 0x53, 0x54, - 0x2d, -]; +const LONG_DST_PREFIX: &[u8; 17] = b"H2C-OVERSIZE-DST-"; -pub(super) struct ExpanderXof { - pub(super) xofer: T, - pub(super) dst: Vec, - pub(super) k: usize, -} +/// Implements section [5.3.3](https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-16#section-5.3.3) +/// "Using DSTs longer than 255 bytes" of the +/// [IRTF CFRG hash-to-curve draft #16](https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-16#section-5.3.3). +pub struct DST(arrayvec::ArrayVec); -impl Expander for ExpanderXof { - fn construct_dst_prime(&self) -> Vec { - let mut dst_prime = if self.dst.len() > MAX_DST_LENGTH { - let mut xofer = self.xofer.clone(); - xofer.update(&LONG_DST_PREFIX.clone()); - xofer.update(&self.dst); - xofer.finalize_boxed((2 * self.k + 7) >> 3).to_vec() +impl DST { + pub fn new_xmd(dst: &[u8]) -> DST { + let array = if dst.len() > MAX_DST_LENGTH { + let mut long = H::default(); + long.update(&LONG_DST_PREFIX[..]); + long.update(&dst); + ArrayVec::try_from(long.finalize_fixed().as_ref()).unwrap() } else { - self.dst.clone() + ArrayVec::try_from(dst).unwrap() }; - dst_prime.push(dst_prime.len() as u8); - dst_prime + DST(array) } - fn expand(&self, msg: &[u8], n: usize) -> Vec { - let dst_prime = self.construct_dst_prime(); - let lib_str = &[((n >> 8) & 0xFF) as u8, (n & 0xFF) as u8]; - let mut xofer = self.xofer.clone(); + pub fn new_xof(dst: &[u8], k: usize) -> DST { + let array = if dst.len() > MAX_DST_LENGTH { + let mut long = H::default(); + long.update(&LONG_DST_PREFIX[..]); + long.update(&dst); + + let mut new_dst = [0u8; MAX_DST_LENGTH]; + let new_dst = &mut new_dst[0..((2 * k + 7) >> 3)]; + long.finalize_xof_into(new_dst); + ArrayVec::try_from(&*new_dst).unwrap() + } else { + ArrayVec::try_from(dst).unwrap() + }; + DST(array) + } + + pub fn update(&self, h: &mut H) { + h.update(self.0.as_ref()); + // I2OSP(len,1) https://www.rfc-editor.org/rfc/rfc8017.txt + h.update(&[self.0.len() as u8]); + } +} + +pub(super) struct ExpanderXof { + pub(super) xofer: PhantomData, + pub(super) dst: Vec, + pub(super) k: usize, +} + +impl Expander for ExpanderXof { + fn expand(&self, msg: &[u8], n: usize) -> Vec { + let mut xofer = H::default(); xofer.update(msg); - xofer.update(lib_str); - xofer.update(&dst_prime); - xofer.finalize_boxed(n).to_vec() + + // I2OSP(len,2) https://www.rfc-editor.org/rfc/rfc8017.txt + let lib_str = (n as u16).to_be_bytes(); + xofer.update(&lib_str); + + DST::new_xof::(self.dst.as_ref(), self.k).update(&mut xofer); + xofer.finalize_boxed(n).into_vec() } } -pub(super) struct ExpanderXmd { - pub(super) hasher: T, +pub(super) struct ExpanderXmd { + pub(super) hasher: PhantomData, pub(super) dst: Vec, pub(super) block_size: usize, } -impl Expander for ExpanderXmd { - fn construct_dst_prime(&self) -> Vec { - let mut dst_prime = if self.dst.len() > MAX_DST_LENGTH { - let mut hasher = self.hasher.clone(); - hasher.update(&LONG_DST_PREFIX); - hasher.update(&self.dst); - hasher.finalize_reset().to_vec() - } else { - self.dst.clone() - }; - dst_prime.push(dst_prime.len() as u8); - dst_prime - } +static Z_PAD: [u8; 256] = [0u8; 256]; + +impl Expander for ExpanderXmd { fn expand(&self, msg: &[u8], n: usize) -> Vec { - let mut hasher = self.hasher.clone(); + use digest::typenum::Unsigned; // output size of the hash function, e.g. 32 bytes = 256 bits for sha2::Sha256 - let b_len = hasher.output_size(); + let b_len = H::OutputSize::to_usize(); let ell = (n + (b_len - 1)) / b_len; assert!( ell <= 255, "The ratio of desired output to the output size of hash function is too large!" ); - let dst_prime = self.construct_dst_prime(); - let z_pad: Vec = vec![0; self.block_size]; + let dst_prime = DST::new_xmd::(self.dst.as_ref()); // Represent `len_in_bytes` as a 2-byte array. // As per I2OSP method outlined in https://tools.ietf.org/pdf/rfc8017.pdf, // The program should abort if integer that we're trying to convert is too large. assert!(n < (1 << 16), "Length should be smaller than 2^16"); let lib_str: [u8; 2] = (n as u16).to_be_bytes(); - hasher.update(&z_pad); + let mut hasher = H::default(); + hasher.update(&Z_PAD[0..self.block_size]); hasher.update(msg); hasher.update(&lib_str); hasher.update(&[0u8]); - hasher.update(&dst_prime); - let b0 = hasher.finalize_reset(); + dst_prime.update(&mut hasher); + let b0 = hasher.finalize_fixed_reset(); hasher.update(&b0); hasher.update(&[1u8]); - hasher.update(&dst_prime); - let mut bi = hasher.finalize_reset(); + dst_prime.update(&mut hasher); + let mut bi = hasher.finalize_fixed_reset(); let mut uniform_bytes: Vec = Vec::with_capacity(n); uniform_bytes.extend_from_slice(&bi); @@ -103,11 +123,12 @@ impl Expander for ExpanderXmd { hasher.update(&[*l ^ *r]); } hasher.update(&[i as u8]); - hasher.update(&dst_prime); - bi = hasher.finalize_reset(); + dst_prime.update(&mut hasher); + bi = hasher.finalize_fixed_reset(); uniform_bytes.extend_from_slice(&bi); } - uniform_bytes[0..n].to_vec() + uniform_bytes.truncate(n); + uniform_bytes } } diff --git a/ff/src/fields/field_hashers/expander/tests.rs b/ff/src/fields/field_hashers/expander/tests.rs index 36b4190f9..eeebc64a7 100644 --- a/ff/src/fields/field_hashers/expander/tests.rs +++ b/ff/src/fields/field_hashers/expander/tests.rs @@ -5,6 +5,7 @@ use sha3::{Shake128, Shake256}; use std::{ fs::{read_dir, File}, io::BufReader, + marker::PhantomData, }; use super::{Expander, ExpanderXmd, ExpanderXof}; @@ -99,29 +100,29 @@ fn get_expander(id: ExpID, _dst: &[u8], k: usize) -> Box { match id { ExpID::XMD(h) => match h { HashID::SHA256 => Box::new(ExpanderXmd { - hasher: Sha256::default(), + hasher: PhantomData::, block_size: 64, dst, }), HashID::SHA384 => Box::new(ExpanderXmd { - hasher: Sha384::default(), + hasher: PhantomData::, block_size: 128, dst, }), HashID::SHA512 => Box::new(ExpanderXmd { - hasher: Sha512::default(), + hasher: PhantomData::, block_size: 128, dst, }), }, ExpID::XOF(x) => match x { XofID::SHAKE128 => Box::new(ExpanderXof { - xofer: Shake128::default(), + xofer: PhantomData::, k, dst, }), XofID::SHAKE256 => Box::new(ExpanderXof { - xofer: Shake256::default(), + xofer: PhantomData::, k, dst, }), diff --git a/ff/src/fields/field_hashers/mod.rs b/ff/src/fields/field_hashers/mod.rs index bfd44f231..f8bd0b26b 100644 --- a/ff/src/fields/field_hashers/mod.rs +++ b/ff/src/fields/field_hashers/mod.rs @@ -1,9 +1,10 @@ mod expander; +use core::marker::PhantomData; + use crate::{Field, PrimeField}; -use ark_std::vec::Vec; -use digest::DynDigest; +use digest::{FixedOutputReset, XofReader}; use expander::Expander; use self::expander::ExpanderXmd; @@ -17,8 +18,8 @@ pub trait HashToField: Sized { /// * `domain` - bytes that get concatenated with the `msg` during hashing, in order to separate potentially interfering instantiations of the hasher. fn new(domain: &[u8]) -> Self; - /// Hash an arbitrary `msg` to #`count` elements from field `F`. - fn hash_to_field(&self, msg: &[u8], count: usize) -> Vec; + /// Hash an arbitrary `msg` to `N` elements of the field `F`. + fn hash_to_field(&self, msg: &[u8]) -> [F; N]; } /// This field hasher constructs a Hash-To-Field based on a fixed-output hash function, @@ -33,16 +34,16 @@ pub trait HashToField: Sized { /// use sha2::Sha256; /// /// let hasher = as HashToField>::new(&[1, 2, 3]); -/// let field_elements: Vec = hasher.hash_to_field(b"Hello, World!", 2); +/// let field_elements: [Fq; 2] = hasher.hash_to_field(b"Hello, World!"); /// /// assert_eq!(field_elements.len(), 2); /// ``` -pub struct DefaultFieldHasher { +pub struct DefaultFieldHasher { expander: ExpanderXmd, len_per_base_elem: usize, } -impl HashToField +impl HashToField for DefaultFieldHasher { fn new(dst: &[u8]) -> Self { @@ -51,7 +52,7 @@ impl HashToFie let len_per_base_elem = get_len_per_elem::(); let expander = ExpanderXmd { - hasher: H::default(), + hasher: PhantomData, dst: dst.to_vec(), block_size: len_per_base_elem, }; @@ -62,38 +63,49 @@ impl HashToFie } } - fn hash_to_field(&self, message: &[u8], count: usize) -> Vec { + fn hash_to_field(&self, message: &[u8]) -> [F; N] { let m = F::extension_degree() as usize; - // The user imposes a `count` of elements of F_p^m to output per input msg, + // The user requests `N` of elements of F_p^m to output per input msg, // each field element comprising `m` BasePrimeField elements. - let len_in_bytes = count * m * self.len_per_base_elem; + let len_in_bytes = N * m * self.len_per_base_elem; let uniform_bytes = self.expander.expand(message, len_in_bytes); - let mut output = Vec::with_capacity(count); - let mut base_prime_field_elems = Vec::with_capacity(m); - for i in 0..count { - base_prime_field_elems.clear(); - for j in 0..m { + let cb = |i| { + let base_prime_field_elem = |j| { let elm_offset = self.len_per_base_elem * (j + i * m); - let val = F::BasePrimeField::from_be_bytes_mod_order( + F::BasePrimeField::from_be_bytes_mod_order( &uniform_bytes[elm_offset..][..self.len_per_base_elem], - ); - base_prime_field_elems.push(val); - } - let f = F::from_base_prime_field_elems(base_prime_field_elems.drain(..)).unwrap(); - output.push(f); - } - - output + ) + }; + F::from_base_prime_field_elems((0..m).map(base_prime_field_elem)).unwrap() + }; + ark_std::array::from_fn::(cb) } } +pub fn hash_to_field(h: &mut H) -> F { + // The final output of `hash_to_field` will be an array of field + // elements from F::BaseField, each of size `len_per_elem`. + let len_per_base_elem = get_len_per_elem::(); + // Rust *still* lacks alloca, hence this ugly hack. + let mut alloca = [0u8; 2048]; + let alloca = &mut alloca[0..len_per_base_elem]; + + let m = F::extension_degree() as usize; + + let base_prime_field_elem = |_| { + h.read(alloca); + F::BasePrimeField::from_be_bytes_mod_order(alloca) + }; + F::from_base_prime_field_elems((0..m).map(base_prime_field_elem)).unwrap() +} + /// This function computes the length in bytes that a hash function should output /// for hashing an element of type `Field`. /// See section 5.1 and 5.3 of the /// [IETF hash standardization draft](https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/14/) -fn get_len_per_elem() -> usize { +const fn get_len_per_elem() -> usize { // ceil(log(p)) let base_field_size_in_bits = F::BasePrimeField::MODULUS_BIT_SIZE as usize; // ceil(log(p)) + security_parameter diff --git a/test-templates/src/h2c/mod.rs b/test-templates/src/h2c/mod.rs index 4cd52eb35..77b9a1a42 100644 --- a/test-templates/src/h2c/mod.rs +++ b/test-templates/src/h2c/mod.rs @@ -52,11 +52,11 @@ macro_rules! test_h2c { for v in data.vectors.iter() { // first, hash-to-field tests - let got: Vec<$base_prime_field> = - hasher.hash_to_field(&v.msg.as_bytes(), 2 * $m); + let got: [$base_prime_field; { 2 * $m }] = + hasher.hash_to_field(&v.msg.as_bytes()); let want: Vec<$base_prime_field> = v.u.iter().map(read_fq_vec).flatten().collect(); - assert_eq!(got, want); + assert_eq!(got[..], *want); // then, test curve points let x = read_fq_vec(&v.p.x);