Skip to content

Commit

Permalink
feat: add VariableByteArray (#88)
Browse files Browse the repository at this point in the history
* feat: add VariableByteArray

* fix: correct type in panic msg

* feat: make MAX_VAR_LEN const generic

* feat: add `SafeBool` and `SafeByte` types

These are very common so we have separate wrapper to avoid the extra length 1
vector heap allocation.

* wip: add VarLenBytes

* Refactor VarLenBytes
Add VarLenBytesVec and FixLenBytes
Fix tests

* Add unsafe methods for bytes
Address NITs

---------

Co-authored-by: Jonathan Wang <[email protected]>
Co-authored-by: Xinding Wei <[email protected]>
  • Loading branch information
3 people authored Aug 18, 2023
1 parent a7b5433 commit f724c9b
Show file tree
Hide file tree
Showing 8 changed files with 488 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# Local IDE configs
.idea/
.vscode/
=======
/target

Expand Down
7 changes: 6 additions & 1 deletion halo2-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ rayon = "1.7"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
log = "0.4"
getset = "0.1.2"

# Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on
halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true }
Expand Down Expand Up @@ -50,7 +51,11 @@ mimalloc = { version = "0.1", default-features = false, optional = true }
[features]
default = ["halo2-axiom", "display"]
asm = ["halo2_proofs_axiom?/asm"]
dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"]
dev-graph = [
"halo2_proofs?/dev-graph",
"halo2_proofs_axiom?/dev-graph",
"plotters",
]
halo2-pse = ["halo2_proofs/circuit-params"]
halo2-axiom = ["halo2_proofs_axiom"]
display = []
Expand Down
90 changes: 90 additions & 0 deletions halo2-base/src/safe_types/bytes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#![allow(clippy::len_without_is_empty)]
use crate::AssignedValue;

use super::{SafeByte, ScalarField};

use getset::Getters;

/// Represents a variable length byte array in circuit.
///
/// Each element is guaranteed to be a byte, given by type [`SafeByte`].
/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide.
/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s).
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytes<F: ScalarField, const MAX_LEN: usize> {
/// The byte array, right padded
#[getset(get = "pub")]
bytes: [SafeByte<F>; MAX_LEN],
/// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN`
#[getset(get = "pub")]
len: AssignedValue<F>,
}

impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
// VarLenBytes can be only created by SafeChip.
pub(super) fn new(bytes: [SafeByte<F>; MAX_LEN], len: AssignedValue<F>) -> Self {
assert!(
len.value().le(&F::from(MAX_LEN as u64)),
"Invalid length which exceeds MAX_LEN {MAX_LEN}",
);
Self { bytes, len }
}

/// Returns the maximum length of the byte array.
pub fn max_len(&self) -> usize {
MAX_LEN
}
}

/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
///
/// Each element is guaranteed to be a byte, given by type [`SafeByte`].
/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is provided when constructing and `bytes.len()` == `MAX_LEN` is enforced.
/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s).
#[derive(Debug, Clone, Getters)]
pub struct VarLenBytesVec<F: ScalarField> {
/// The byte array, right padded
#[getset(get = "pub")]
bytes: Vec<SafeByte<F>>,
/// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN`
#[getset(get = "pub")]
len: AssignedValue<F>,
}

impl<F: ScalarField> VarLenBytesVec<F> {
// VarLenBytesVec can be only created by SafeChip.
pub(super) fn new(bytes: Vec<SafeByte<F>>, len: AssignedValue<F>, max_len: usize) -> Self {
assert!(
len.value().le(&F::from_u128(max_len as u128)),
"Invalid length which exceeds MAX_LEN {}",
max_len
);
assert!(bytes.len() == max_len, "bytes is not padded correctly");
Self { bytes, len }
}

/// Returns the maximum length of the byte array.
pub fn max_len(&self) -> usize {
self.bytes.len()
}
}

/// Represents a fixed length byte array in circuit.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytes<F: ScalarField, const LEN: usize> {
/// The byte array
#[getset(get = "pub")]
bytes: [SafeByte<F>; LEN],
}

impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
// FixLenBytes can be only created by SafeChip.
pub(super) fn new(bytes: [SafeByte<F>; LEN]) -> Self {
Self { bytes }
}

/// Returns the length of the byte array.
pub fn len(&self) -> usize {
LEN
}
}
150 changes: 142 additions & 8 deletions halo2-base/src/safe_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@ pub use crate::{
flex_gate::GateInstructions,
range::{RangeChip, RangeInstructions},
},
safe_types::VarLenBytes,
utils::ScalarField,
AssignedValue, Context,
QuantumCell::{self, Constant, Existing, Witness},
};
use std::cmp::{max, min};
use std::{
borrow::{Borrow, BorrowMut},
cmp::{max, min},
};

mod bytes;
mod primitives;

pub use bytes::*;
use itertools::Itertools;
pub use primitives::*;

#[cfg(test)]
pub mod tests;
Expand Down Expand Up @@ -54,20 +65,26 @@ impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>
Self { value: raw_values }
}

/// Return values in littile-endian.
pub fn value(&self) -> &RawAssignedValues<F> {
/// Return values in little-endian.
pub fn value(&self) -> &[AssignedValue<F>] {
&self.value
}
}

impl<F: ScalarField, const BYTES_PER_ELE: usize, const TOTAL_BITS: usize> AsRef<[AssignedValue<F>]>
for SafeType<F, BYTES_PER_ELE, TOTAL_BITS>
{
fn as_ref(&self) -> &[AssignedValue<F>] {
self.value()
}
}

/// Represent TOTAL_BITS with the least number of AssignedValue<F>.
/// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1]
#[allow(type_alias_bounds)]
type CompactSafeType<F: ScalarField, const TOTAL_BITS: usize> =
SafeType<F, { ((F::NUM_BITS - 1) / 8) as usize }, TOTAL_BITS>;
SafeType<F, { (F::CAPACITY / 8) as usize }, TOTAL_BITS>;

/// SafeType for bool.
pub type SafeBool<F> = CompactSafeType<F, 1>;
/// SafeType for uint8.
pub type SafeUint8<F> = CompactSafeType<F, 8>;
/// SafeType for uint16.
Expand Down Expand Up @@ -98,7 +115,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
Self { range_chip }
}

/// Convert a vector of AssignedValue(treated as little-endian) to a SafeType.
/// Convert a vector of AssignedValue (treated as little-endian) to a SafeType.
/// The number of bytes of inputs must equal to the number of bytes of outputs.
/// This function also add contraints that a AssignedValue in inputs must be in the range of a byte.
pub fn raw_bytes_to<const BYTES_PER_ELE: usize, const TOTAL_BITS: usize>(
Expand Down Expand Up @@ -134,6 +151,123 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
SafeType::<F, BYTES_PER_ELE, TOTAL_BITS>::new(value)
}

/// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`].
pub fn assert_bool(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeBool<F> {
self.range_chip.gate().assert_bit(ctx, input);
SafeBool(input)
}

/// Load a boolean value as witness and constrain it is either 0 or 1.
pub fn load_bool(&self, ctx: &mut Context<F>, input: bool) -> SafeBool<F> {
let input = ctx.load_witness(F::from(input));
self.assert_bool(ctx, input)
}

/// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeBool`].
pub fn unsafe_to_bool(&self, input: AssignedValue<F>) -> SafeBool<F> {
SafeBool(input)
}

/// Constrains that the `input` is a byte value and wraps it in [`SafeByte`].
pub fn assert_byte(&self, ctx: &mut Context<F>, input: AssignedValue<F>) -> SafeByte<F> {
self.range_chip.range_check(ctx, input, BITS_PER_BYTE);
SafeByte(input)
}

/// Load a boolean value as witness and constrain it is either 0 or 1.
pub fn load_byte(&self, ctx: &mut Context<F>, input: u8) -> SafeByte<F> {
let input = ctx.load_witness(F::from(input as u64));
self.assert_byte(ctx, input)
}

/// Unsafe method that directly converts `input` to [`SafeByte`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_byte(input: AssignedValue<F>) -> SafeByte<F> {
SafeByte(input)
}

/// Unsafe method that directly converts `inputs` to [`VarLenBytes`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_var_len_bytes<const MAX_LEN: usize>(
inputs: [AssignedValue<F>; MAX_LEN],
len: AssignedValue<F>,
) -> VarLenBytes<F, MAX_LEN> {
VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)), len)
}

/// Unsafe method that directly converts `inputs` to [`VarLenBytesVec`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_var_len_bytes_vec(
inputs: RawAssignedValues<F>,
len: AssignedValue<F>,
max_len: usize,
) -> VarLenBytesVec<F> {
VarLenBytesVec::<F>::new(
inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(),
len,
max_len,
)
}

/// Unsafe method that directly converts `inputs` to [`FixLenBytes`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_fix_len_bytes<const MAX_LEN: usize>(
inputs: [AssignedValue<F>; MAX_LEN],
) -> FixLenBytes<F, MAX_LEN> {
FixLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)))
}

/// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Slice representing the byte array.
/// * len: [AssignedValue]<F> witness representing the variable elements within the byte array from 0..=len.
/// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain.
pub fn raw_to_var_len_bytes<const MAX_LEN: usize>(
&self,
ctx: &mut Context<F>,
inputs: [AssignedValue<F>; MAX_LEN],
len: AssignedValue<F>,
) -> VarLenBytes<F, MAX_LEN> {
self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64);
VarLenBytes::<F, MAX_LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)), len)
}

/// Converts a vector of AssignedValue(treated as little-endian) to VarLenBytesVec. Not encourged to use because `MAX_LEN` cannot be verified at compile time.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Vector representing the byte array.
/// * len: [AssignedValue]<F> witness representing the variable elements within the byte array from 0..=len.
/// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain.
pub fn raw_to_var_len_bytes_vec(
&self,
ctx: &mut Context<F>,
inputs: RawAssignedValues<F>,
len: AssignedValue<F>,
max_len: usize,
) -> VarLenBytesVec<F> {
self.range_chip.check_less_than_safe(ctx, len, max_len as u64);
VarLenBytesVec::<F>::new(
inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(),
len,
max_len,
)
}

/// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Slice representing the byte array.
/// * LEN: length of the byte array.
pub fn raw_to_fix_len_bytes<const LEN: usize>(
&self,
ctx: &mut Context<F>,
inputs: [AssignedValue<F>; LEN],
) -> FixLenBytes<F, LEN> {
FixLenBytes::<F, LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)))
}

fn add_bytes_constraints(
&self,
ctx: &mut Context<F>,
Expand All @@ -148,6 +282,6 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
}
}

// TODO: Add comprasion. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool
// TODO: Add comparison. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool
// TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64
}
47 changes: 47 additions & 0 deletions halo2-base/src/safe_types/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use super::*;
/// SafeType for bool (1 bit).
///
/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because
/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid
/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector.
#[derive(Clone, Copy, Debug)]
pub struct SafeBool<F: ScalarField>(pub(super) AssignedValue<F>);

/// SafeType for byte (8 bits).
///
/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because
/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid
/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector.
#[derive(Clone, Copy, Debug)]
pub struct SafeByte<F: ScalarField>(pub(super) AssignedValue<F>);

macro_rules! safe_primitive_impls {
($SafePrimitive:ty) => {
impl<F: ScalarField> AsRef<AssignedValue<F>> for $SafePrimitive {
fn as_ref(&self) -> &AssignedValue<F> {
&self.0
}
}

impl<F: ScalarField> AsMut<AssignedValue<F>> for $SafePrimitive {
fn as_mut(&mut self) -> &mut AssignedValue<F> {
&mut self.0
}
}

impl<F: ScalarField> Borrow<AssignedValue<F>> for $SafePrimitive {
fn borrow(&self) -> &AssignedValue<F> {
&self.0
}
}

impl<F: ScalarField> BorrowMut<AssignedValue<F>> for $SafePrimitive {
fn borrow_mut(&mut self) -> &mut AssignedValue<F> {
&mut self.0
}
}
};
}

safe_primitive_impls!(SafeBool<F>);
safe_primitive_impls!(SafeByte<F>);
Loading

0 comments on commit f724c9b

Please sign in to comment.