diff --git a/src/primitives/iter.rs b/src/primitives/iter.rs index 65b4ff9d..7056cb1f 100644 --- a/src/primitives/iter.rs +++ b/src/primitives/iter.rs @@ -56,7 +56,26 @@ pub trait Fe32IterExt: Sized + Iterator { /// are simply dropped. #[inline] fn fes_to_bytes(mut self) -> FesToBytes { - FesToBytes { last_fe: self.next(), bit_offset: 0, iter: self } + FesToBytes { + last_fe: self.next(), + bit_offset: 0, + iter: self, + output_incomplete_bits_zeropad: false + } + } + + /// Adapts the `Fe32` iterator to output bytes instead. + /// + /// If the total number of bits is not a multiple of 8, trailing bits + /// are padded with the needed amount of zeroes and converted. + #[inline] + fn fes_to_bytes_zeropad(mut self) -> FesToBytes { + FesToBytes { + last_fe: self.next(), + bit_offset: 0, + iter: self, + output_incomplete_bits_zeropad: true + } } /// Adapts the Fe32 iterator to encode the field elements into a bech32 address. @@ -148,7 +167,8 @@ where /// Iterator adaptor that converts GF32 elements to bytes. /// -/// If the total number of bits is not a multiple of 8, any trailing bits are dropped. +/// If the total number of bits is not a multiple of 8. Any trailing bits are dropped, +/// unless `output_incomplete_bits_zeropad` is set, in which case they are padded with zeroes. /// /// Note that if there are 5 or more trailing bits, the result will be that an entire field element /// is dropped. If this occurs, the input was an invalid length for a bech32 string, but this @@ -158,6 +178,7 @@ pub struct FesToBytes> { last_fe: Option, bit_offset: usize, iter: I, + output_incomplete_bits_zeropad: bool, } impl Iterator for FesToBytes @@ -177,10 +198,18 @@ where let mut ret = last.0 << (3 + bit_offset); self.last_fe = self.iter.next(); - let next1 = self.last_fe?; + let next1 = if !self.output_incomplete_bits_zeropad { + self.last_fe? + } else { + self.last_fe.unwrap_or_default() + }; if bit_offset > 2 { self.last_fe = self.iter.next(); - let next2 = self.last_fe?; + let next2 = if !self.output_incomplete_bits_zeropad { + self.last_fe? + } else { + self.last_fe.unwrap_or_default() + }; ret |= next1.0 << (bit_offset - 2); ret |= next2.0 >> (7 - bit_offset); } else { @@ -503,4 +532,32 @@ mod tests { const FES: [Fe32; 3] = [Fe32::Q, Fe32::P, Fe32::Q]; assert!(FES.iter().copied().fes_to_bytes().bytes_to_fes().eq(FES.iter().copied())) } + + #[test] + fn fe32_iter_ext_zeropad_and_nozeropad() { + use std::convert::TryFrom; + { + // Difference is 1 output byte, containing 4 trailing bits + let fes_iter = [0, 1, 2, 1].iter().copied().map(|b| Fe32::try_from(b).unwrap()); + assert_eq!(fes_iter.clone().fes_to_bytes_zeropad().collect::>(), [0, 68, 16]); + assert_eq!(fes_iter.clone().fes_to_bytes().collect::>(), [0, 68]); + } + { + // Difference is 1 output byte, containing 1 trailing bit + let fes_iter = [0, 1, 2, 3, 31].iter().copied().map(|b| Fe32::try_from(b).unwrap()); + assert_eq!( + fes_iter.clone().fes_to_bytes_zeropad().collect::>(), + [0, 68, 63, 128] + ); + assert_eq!(fes_iter.clone().fes_to_bytes().collect::>(), [0, 68, 63]); + } + { + // No difference here, as the input (32*5=160 bits) has no trailing bits + let fes_iter = "w508d6qejxtdg4y5r3zarvary0c5xw7k" + .bytes() + .map(|b| Fe32::from_char(char::from(b)).unwrap()); + assert_eq!(fes_iter.clone().fes_to_bytes_zeropad().collect::>(), DATA); + assert_eq!(fes_iter.clone().fes_to_bytes().collect::>(), DATA); + } + } }