Skip to content

Commit

Permalink
fix: Allen's Poseidon2 fixes (#1099)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue authored Jul 17, 2024
1 parent eda543c commit e62e8a1
Show file tree
Hide file tree
Showing 14 changed files with 209 additions and 72 deletions.
4 changes: 2 additions & 2 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,12 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
_ => unimplemented!(),
}
}
DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input) => match input {
DslIr::Poseidon2AbsorbBabyBear(p2_hash_and_absorb_num, input) => match input {
Array::Dyn(input, input_size) => {
if let Usize::Var(input_size) = input_size {
self.push(
AsmInstruction::Poseidon2Absorb(
p2_hash_num.fp(),
p2_hash_and_absorb_num.fp(),
input.fp(),
input_size.fp(),
),
Expand Down
30 changes: 16 additions & 14 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,17 +854,19 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
"".to_string(),
),
AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => Instruction::new(
Opcode::Poseidon2Absorb,
i32_f(hash_num),
i32_f_arr(input_ptr),
i32_f_arr(input_len),
F::zero(),
F::zero(),
false,
false,
"".to_string(),
),
AsmInstruction::Poseidon2Absorb(hash_and_absorb_num, input_ptr, input_len) => {
Instruction::new(
Opcode::Poseidon2Absorb,
i32_f(hash_and_absorb_num),
i32_f_arr(input_ptr),
i32_f_arr(input_len),
F::zero(),
F::zero(),
false,
false,
"".to_string(),
)
}
AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => Instruction::new(
Opcode::Poseidon2Finalize,
i32_f(hash_num),
Expand Down Expand Up @@ -1174,15 +1176,15 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
result, src1, src2
)
}
AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => {
AsmInstruction::Poseidon2Absorb(hash_and_absorb_num, input_ptr, input_len) => {
write!(
f,
"poseidon2_absorb ({})fp, {})fp, ({})fp",
hash_num, input_ptr, input_len,
hash_and_absorb_num, input_ptr, input_len,
)
}
AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => {
write!(f, "poseidon2_finalize ({})fp, {})fp", hash_num, output_ptr,)
write!(f, "poseidon2_finalize ({})fp, ({})fp", hash_num, output_ptr,)
}
AsmInstruction::Commit(val, index) => {
write!(f, "commit ({})fp ({})fp", val, index)
Expand Down
18 changes: 14 additions & 4 deletions recursion/compiler/src/ir/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ impl<C: Config> Builder<C> {
/// Applies the Poseidon2 absorb function to the given array.
///
/// Reference: [p3_symmetric::PaddingFreeSponge]
pub fn poseidon2_absorb(&mut self, p2_hash_num: Var<C::N>, input: &Array<C, Felt<C::F>>) {
self.operations
.push(DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input.clone()));
pub fn poseidon2_absorb(
&mut self,
p2_hash_and_absorb_num: Var<C::N>,
input: &Array<C, Felt<C::F>>,
) {
self.operations.push(DslIr::Poseidon2AbsorbBabyBear(
p2_hash_and_absorb_num,
input.clone(),
));
}

/// Applies the Poseidon2 finalize to the given hash number.
Expand Down Expand Up @@ -128,9 +134,13 @@ impl<C: Config> Builder<C> {
self.cycle_tracker("poseidon2-hash");

let p2_hash_num = self.p2_hash_num;
let two_power_12: Var<_> = self.eval(C::N::from_canonical_u32(1 << 12));

self.range(0, array.len()).for_each(|i, builder| {
let subarray = builder.get(array, i);
builder.poseidon2_absorb(p2_hash_num, &subarray);
let p2_hash_and_absorb_num: Var<_> = builder.eval(p2_hash_num * two_power_12 + i);

builder.poseidon2_absorb(p2_hash_and_absorb_num, &subarray);
});

let output: Array<C, Felt<C::F>> = self.dyn_array(DIGEST_SIZE);
Expand Down
122 changes: 96 additions & 26 deletions recursion/core/src/poseidon2_wide/air/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_row.syscall_params(),
send_range_check,
);

builder
.when(local_control_flow.is_syscall_row)
.assert_one(local_is_real);
}

/// This function will verify that all hash rows are before the compress rows and that the first
Expand All @@ -80,47 +84,67 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_is_real: AB::Expr,
next_is_real: AB::Expr,
) {
// We require that the first row is an absorb syscall and that the hash_num == 0.
// We require that the first row is an absorb syscall and that the hash_num == 0 and absorb_num == 0.
let mut first_row_builder = builder.when_first_row();
first_row_builder.assert_one(local_control_flow.is_absorb);
first_row_builder.assert_one(local_control_flow.is_syscall_row);
first_row_builder.assert_zero(local_syscall_params.absorb().hash_num);
first_row_builder.assert_zero(local_opcode_workspace.absorb().hash_num);
first_row_builder.assert_zero(local_opcode_workspace.absorb().absorb_num);
first_row_builder.assert_one(local_opcode_workspace.absorb().is_first_hash_row);

let mut transition_builder = builder.when_transition();

// For absorb rows, constrain the following:
// 1) next row is either an absorb or syscall finalize.
// 2) when last absorb row, then the next row is a syscall row.
// 2) hash_num == hash_num'.
// 1) when last absorb row, then the next row is a either an absorb or finalize syscall row.
// 2) when last absorb row and the next row is an absorb row, then absorb_num' = absorb_num + 1.
// 3) when not last absorb row, then the next row is an absorb non syscall row.
// 4) when not last absorb row, then absorb_num' = absorb_num.
// 5) hash_num == hash_num'.
{
let mut absorb_transition_builder =
transition_builder.when(local_control_flow.is_absorb);
absorb_transition_builder
let mut transition_builder = builder.when_transition();

let mut absorb_last_row_builder =
transition_builder.when(local_control_flow.is_absorb_last_row);
absorb_last_row_builder
.assert_one(next_control_flow.is_absorb + next_control_flow.is_finalize);
absorb_transition_builder
.when(local_opcode_workspace.absorb().is_last_row::<AB>())
.assert_one(next_control_flow.is_syscall_row);
absorb_last_row_builder.assert_one(next_control_flow.is_syscall_row);
absorb_last_row_builder
.when(next_control_flow.is_absorb)
.assert_eq(
next_opcode_workspace.absorb().absorb_num,
local_opcode_workspace.absorb().absorb_num + AB::Expr::one(),
);

let mut absorb_not_last_row_builder =
transition_builder.when(local_control_flow.is_absorb_not_last_row);
absorb_not_last_row_builder.assert_one(next_control_flow.is_absorb);
absorb_not_last_row_builder.assert_zero(next_control_flow.is_syscall_row);
absorb_not_last_row_builder.assert_eq(
local_opcode_workspace.absorb().absorb_num,
next_opcode_workspace.absorb().absorb_num,
);

let mut absorb_transition_builder =
transition_builder.when(local_control_flow.is_absorb);
absorb_transition_builder
.when(next_control_flow.is_absorb)
.assert_eq(
local_syscall_params.absorb().hash_num,
next_syscall_params.absorb().hash_num,
local_opcode_workspace.absorb().hash_num,
next_opcode_workspace.absorb().hash_num,
);
absorb_transition_builder
.when(next_control_flow.is_finalize)
.assert_eq(
local_syscall_params.absorb().hash_num,
local_opcode_workspace.absorb().hash_num,
next_syscall_params.finalize().hash_num,
);
}

// For finalize rows, constrain the following:
// 1) next row is syscall compress or syscall absorb.
// 2) if next row is absorb -> hash_num + 1 == hash_num'
// 3) if next row is absorb -> is_first_hash' == true
// 3) if next row is absorb -> absorb_num' == 0
// 4) if next row is absorb -> is_first_hash' == true
{
let mut transition_builder = builder.when_transition();
let mut finalize_transition_builder =
transition_builder.when(local_control_flow.is_finalize);

Expand All @@ -132,8 +156,11 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
.when(next_control_flow.is_absorb)
.assert_eq(
local_syscall_params.finalize().hash_num + AB::Expr::one(),
next_syscall_params.absorb().hash_num,
next_opcode_workspace.absorb().hash_num,
);
finalize_transition_builder
.when(next_control_flow.is_absorb)
.assert_zero(next_opcode_workspace.absorb().absorb_num);
finalize_transition_builder
.when(next_control_flow.is_absorb)
.assert_one(next_opcode_workspace.absorb().is_first_hash_row);
Expand All @@ -143,26 +170,33 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
// 1) if compress syscall -> next row is a compress output
// 2) if compress output -> next row is a compress syscall or not real
{
builder.assert_eq(
local_control_flow.is_compress_output,
local_control_flow.is_compress
* (AB::Expr::one() - local_control_flow.is_syscall_row),
);

let mut transition_builder = builder.when_transition();

transition_builder
.when(local_control_flow.is_compress)
.when(local_control_flow.is_syscall_row)
.assert_one(next_control_flow.is_compress_output);

// When we are at a compress output row, then ensure next row is either not real or is a compress syscall row.
transition_builder
.when(local_control_flow.is_compress_output)
.assert_one(
next_control_flow.is_compress + (AB::Expr::one() - next_is_real.clone()),
(AB::Expr::one() - next_is_real.clone())
+ next_control_flow.is_compress * next_control_flow.is_syscall_row,
);

transition_builder
.when(local_control_flow.is_compress_output)
.when(next_control_flow.is_compress)
.assert_one(next_control_flow.is_syscall_row);
}

// Constrain that there is only one is_real -> not is real transition. Also contrain that
// the last real row is a compress output row.
{
let mut transition_builder = builder.when_transition();

transition_builder
.when_not(local_is_real.clone())
.assert_zero(next_is_real.clone());
Expand Down Expand Up @@ -194,6 +228,29 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
let last_row_ending_cursor_is_seven =
local_hash_workspace.last_row_ending_cursor_is_seven.result;

// Verify that the hash_num and absorb_num are correctly decomposed from the syscall
// hash_and_absorb_num param.
// Also range check that both hash_num is within [0, 2^16 - 1] and absorb_num is within [0, 2^12 - 1];
{
let mut absorb_builder = builder.when(local_control_flow.is_absorb);

absorb_builder.assert_eq(
local_hash_workspace.hash_num * AB::Expr::from_canonical_u32(1 << 12)
+ local_hash_workspace.absorb_num,
local_syscall_params.absorb().hash_and_absorb_num,
);
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_hash_workspace.hash_num,
send_range_check,
);
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U12 as u8),
local_hash_workspace.absorb_num,
send_range_check,
);
}

// Constrain the materialized control flow flags.
{
let mut absorb_builder = builder.when(local_control_flow.is_absorb);
Expand Down Expand Up @@ -232,12 +289,16 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_control_flow.is_absorb
* (AB::Expr::one() - local_hash_workspace.is_last_row::<AB>()),
);
builder.assert_eq(
local_control_flow.is_absorb_last_row,
local_control_flow.is_absorb * local_hash_workspace.is_last_row::<AB>(),
);

builder.assert_eq(
local_control_flow.is_absorb_no_perm,
local_control_flow.is_absorb
* (AB::Expr::one() - local_hash_workspace.do_perm::<AB>()),
)
);
}

// For the absorb syscall row, ensure correct value of num_remaining_rows, last_row_num_consumed,
Expand Down Expand Up @@ -274,7 +335,16 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
expected_last_row_ending_cursor,
);

// Range check that num_remaining_rows is between [0, 2^18-1].
// Range check that input_len < 2^16. This check is only needed for absorb syscall rows,
// but we send it for all absorb rows, since the `is_real` parameter must be an expression
// with at most degree 1.
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_syscall_params.absorb().input_len,
send_range_check,
);

// Range check that num_remaining_rows is between [0, 2^16-1].
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_hash_workspace.num_remaining_rows,
Expand Down
11 changes: 10 additions & 1 deletion recursion/core/src/poseidon2_wide/air/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,21 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
}

// Verify that all elements of start_mem_idx_bitmap and end_mem_idx_bitmap are bool.
// Also verify that exactly one of the bits in start_mem_idx_bitmap and end_mem_idx_bitmap
// is one.
let mut start_mem_idx_bitmap_sum = AB::Expr::zero();
start_mem_idx_bitmap.iter().for_each(|bit| {
absorb_builder.assert_bool(*bit);
start_mem_idx_bitmap_sum += (*bit).into();
});
absorb_builder.assert_one(start_mem_idx_bitmap_sum);

let mut end_mem_idx_bitmap_sum = AB::Expr::zero();
end_mem_idx_bitmap.iter().for_each(|bit| {
absorb_builder.assert_bool(*bit);
end_mem_idx_bitmap_sum += (*bit).into();
});
absorb_builder.assert_one(end_mem_idx_bitmap_sum);

// Verify correct value of start_mem_idx_bitmap and end_mem_idx_bitmap.
let start_mem_idx: AB::Expr = start_mem_idx_bitmap
Expand All @@ -209,7 +218,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
// When we are not in the last row, end_mem_idx should be zero.
absorb_builder
.when_not(opcode_workspace.absorb().is_last_row::<AB>())
.assert_zero(end_mem_idx.clone());
.assert_zero(end_mem_idx.clone() - AB::Expr::from_canonical_usize(7));

// When we are in the last row, end_mem_idx bitmap should equal last_row_ending_cursor.
absorb_builder
Expand Down
4 changes: 2 additions & 2 deletions recursion/core/src/poseidon2_wide/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! # Layout of the poseidon2 chip:
//!
//! All the hash related rows should be in the first part of the chip and all the compress
//! related rows in the second part. E.g. the chip should has this format:
//! related rows in the second part. E.g. the chip should have this format:
//!
//! absorb row (for hash num 1)
//! absorb row (for hash num 1)
Expand Down Expand Up @@ -34,7 +34,7 @@
//! last_row_ending_cursor will be copied down to all of the rows. Also, for the next absorb/finalize
//! syscall, its state_cursor is set to (last_row_ending_cursor + 1) % RATE.
//!
//! From num_remaining_rows and syscall column, we know the absorb 's first row and last row.
//! From num_remaining_rows and syscall column, we know the absorb's first row and last row.
//! From that fact, we can then enforce the following state writes.
//!
//! 1. is_first_row && is_last_row -> state writes are [state_cursor..state_cursor + last_row_ending_cursor]
Expand Down
6 changes: 4 additions & 2 deletions recursion/core/src/poseidon2_wide/air/syscall_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
let next_syscall_params = next_syscall.absorb();

absorb_syscall_builder.assert_eq(local_syscall_params.clk, next_syscall_params.clk);
absorb_syscall_builder
.assert_eq(local_syscall_params.hash_num, next_syscall_params.hash_num);
absorb_syscall_builder.assert_eq(
local_syscall_params.hash_and_absorb_num,
next_syscall_params.hash_and_absorb_num,
);
absorb_syscall_builder.assert_eq(
local_syscall_params.input_ptr,
next_syscall_params.input_ptr,
Expand Down
2 changes: 2 additions & 0 deletions recursion/core/src/poseidon2_wide/columns/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub struct ControlFlow<T> {
pub is_absorb_no_perm: T,
/// Specifies if this row is for an absorb that is not the last row.
pub is_absorb_not_last_row: T,
/// Specifies if this row is for an absorb that is the last row.
pub is_absorb_last_row: T,

/// Specifies if this row is for finalize.
pub is_finalize: T,
Expand Down
2 changes: 2 additions & 0 deletions recursion/core/src/poseidon2_wide/columns/opcode_workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub struct AbsorbWorkspace<T: Copy> {
pub state_cursor: T,

/// Control flow columns.
pub hash_num: T,
pub absorb_num: T,
pub is_first_hash_row: T,
pub num_remaining_rows: T,
pub num_remaining_rows_is_zero: IsZeroOperation<T>,
Expand Down
Loading

0 comments on commit e62e8a1

Please sign in to comment.