Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make MillerLoopResultScalingFactor struct generic on T. #194

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions hydra/garaga/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,22 +1015,45 @@ def extract_from_circuit_output(
) -> str:
raise NotImplementedError("Never used in practice")

def serialize_input_signature(self) -> str:
bits = self.bits
if bits <= 288:
return f"{self.name}:MillerLoopResultScalingFactor<u288>"
else:
return f"{self.name}:MillerLoopResultScalingFactor<u384>"

def dump_to_circuit_input(self) -> str:
code = ""
bits = self.bits
if bits <= 288:
next_fn = "next_u288"
else:
next_fn = "next_2"
for mem_name in self.members_names:
code += f"circuit_inputs = circuit_inputs.next_2({self.name}.{mem_name});\n"
code += (
f"circuit_inputs = circuit_inputs.{next_fn}({self.name}.{mem_name});\n"
)
return code

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 6
raw_struct = f"{self.__class__.__name__}{{{','.join([f'{self.members_names[i]}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
bits = self.bits
if bits <= 288:
curve_id = 0
else:
curve_id = 1
raw_struct = f"{self.__class__.__name__}{{{','.join([f'{self.members_names[i]}: {int_to_u2XX(self.elmts[i].value, curve_id=curve_id)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"

def _serialize_to_calldata(self) -> list[int]:
return io.bigint_split_array(self.elmts, prepend_length=False)
bits = self.bits
if bits <= 288:
return io.bigint_split_array(self.elmts, n_limbs=3, prepend_length=False)
else:
return io.bigint_split_array(self.elmts, n_limbs=4, prepend_length=False)

def __len__(self) -> int:
if self.elmts is not None:
Expand Down
16 changes: 8 additions & 8 deletions src/src/circuits/multi_pairing_check.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2213,7 +2213,7 @@ fn run_BLS12_381_MP_CHECK_INIT_BIT_3P_2F_circuit(
return (Q0, new_lhs);
}
fn run_BLS12_381_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
lambda_root_inverse: E12D<u384>, z: u384, scaling_factor: MillerLoopResultScalingFactor
lambda_root_inverse: E12D<u384>, z: u384, scaling_factor: MillerLoopResultScalingFactor<u384>
) -> (u384, u384, u384) {
// CONSTANT stack
let in0 = CE::<CI<0>> {}; // 0x0
Expand Down Expand Up @@ -5334,7 +5334,7 @@ fn run_BN254_MP_CHECK_INIT_BIT_3P_2F_circuit(
fn run_BN254_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
lambda_root: E12D<u288>,
z: u384,
scaling_factor: MillerLoopResultScalingFactor,
scaling_factor: MillerLoopResultScalingFactor<u288>,
c_inv: E12D<u288>,
c_0: u384
) -> (u384, u384, u384, u384, u384, u384, u384) {
Expand Down Expand Up @@ -5678,12 +5678,12 @@ fn run_BN254_MP_CHECK_PREPARE_LAMBDA_ROOT_circuit(
circuit_inputs = circuit_inputs.next_u288(lambda_root.w10); // in55
circuit_inputs = circuit_inputs.next_u288(lambda_root.w11); // in56
circuit_inputs = circuit_inputs.next_2(z); // in57
circuit_inputs = circuit_inputs.next_2(scaling_factor.w0); // in58
circuit_inputs = circuit_inputs.next_2(scaling_factor.w2); // in59
circuit_inputs = circuit_inputs.next_2(scaling_factor.w4); // in60
circuit_inputs = circuit_inputs.next_2(scaling_factor.w6); // in61
circuit_inputs = circuit_inputs.next_2(scaling_factor.w8); // in62
circuit_inputs = circuit_inputs.next_2(scaling_factor.w10); // in63
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w0); // in58
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w2); // in59
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w4); // in60
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w6); // in61
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w8); // in62
circuit_inputs = circuit_inputs.next_u288(scaling_factor.w10); // in63
circuit_inputs = circuit_inputs.next_u288(c_inv.w0); // in64
circuit_inputs = circuit_inputs.next_u288(c_inv.w1); // in65
circuit_inputs = circuit_inputs.next_u288(c_inv.w2); // in66
Expand Down
14 changes: 7 additions & 7 deletions src/src/definitions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,13 @@ impl E12DSerde288 of Serde<E12D<u288>> {
}

#[derive(Copy, Drop, Debug, PartialEq, Serde)]
struct MillerLoopResultScalingFactor {
w0: u384,
w2: u384,
w4: u384,
w6: u384,
w8: u384,
w10: u384,
struct MillerLoopResultScalingFactor<T> {
w0: T,
w2: T,
w4: T,
w6: T,
w8: T,
w10: T,
}
#[derive(Copy, Drop, Debug, PartialEq, Serde)]
struct E12DMulQuotient {
Expand Down
4 changes: 2 additions & 2 deletions src/src/groth16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ fn multi_pairing_check_bn254_3P_2F_with_extra_miller_loop_result(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair2, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(mpcheck_hint.lambda_root, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(mpcheck_hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(mpcheck_hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u288(mpcheck_hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u288_transcript(mpcheck_hint.Ris, s0, s1, s2);

Expand Down Expand Up @@ -514,7 +514,7 @@ fn multi_pairing_check_bls12_381_3P_2F_with_extra_miller_loop_result(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair2, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u384(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u384(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u384_transcript(hint.Ris, s0, s1, s2);
let mut c_i: u384 = s1.into();
Expand Down
8 changes: 4 additions & 4 deletions src/src/pairing_check.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ use garaga::basic_field_ops::{compute_yInvXnegOverY_BN254, compute_yInvXnegOverY
struct MPCheckHintBN254 {
lambda_root: E12D<u288>,
lambda_root_inverse: E12D<u288>,
w: MillerLoopResultScalingFactor,
w: MillerLoopResultScalingFactor<u288>,
Ris: Span<E12D<u288>>,
big_Q: Array<u288>,
}

#[derive(Drop, Serde)]
struct MPCheckHintBLS12_381 {
lambda_root_inverse: E12D<u384>,
w: MillerLoopResultScalingFactor,
w: MillerLoopResultScalingFactor<u384>,
Ris: Span<E12D<u384>>,
big_Q: Array<u384>,
}
Expand All @@ -70,7 +70,7 @@ fn multi_pairing_check_bn254_2P_2F(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(hint.lambda_root, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u288(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u288(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u288_transcript(hint.Ris, s0, s1, s2);
let mut c_i: u384 = s1.into();
Expand Down Expand Up @@ -233,7 +233,7 @@ fn multi_pairing_check_bls12_381_2P_2F(
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair0, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_G1G2Pair(pair1, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_E12D_u384(hint.lambda_root_inverse, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor(hint.w, s0, s1, s2);
let (s0, s1, s2) = hashing::hash_MillerLoopResultScalingFactor_u384(hint.w, s0, s1, s2);
// Hash Ris to obtain base random coefficient c0
let (s0, s1, s2) = hashing::hash_E12D_u384_transcript(hint.Ris, s0, s1, s2);

Expand Down
24 changes: 12 additions & 12 deletions src/src/tests/pairing_tests.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4097,12 +4097,12 @@ mod pairing_tests {
}
},
w: MillerLoopResultScalingFactor {
w0: u384 { limb0: 0x1, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w2: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w4: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w6: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w8: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w10: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 }
w0: u288 { limb0: 0x1, limb1: 0x0, limb2: 0x0 },
w2: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w4: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w6: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w8: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w10: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 }
},
Ris: array![
E12D {
Expand Down Expand Up @@ -11917,12 +11917,12 @@ mod pairing_tests {
}
},
w: MillerLoopResultScalingFactor {
w0: u384 { limb0: 0x1, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w2: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w4: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w6: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w8: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 },
w10: u384 { limb0: 0x0, limb1: 0x0, limb2: 0x0, limb3: 0x0 }
w0: u288 { limb0: 0x1, limb1: 0x0, limb2: 0x0 },
w2: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w4: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w6: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w8: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 },
w10: u288 { limb0: 0x0, limb1: 0x0, limb2: 0x0 }
},
Ris: array![
E12D {
Expand Down
30 changes: 28 additions & 2 deletions src/src/utils/hashing.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ pub fn hash_E12D_u288(

// Apply sponge construction to a MillerLoopResultScalingFactor element from an initial state (s0,
// s1, s2)
pub fn hash_MillerLoopResultScalingFactor(
elmt: MillerLoopResultScalingFactor, mut s0: felt252, mut s1: felt252, mut s2: felt252
pub fn hash_MillerLoopResultScalingFactor_u384(
elmt: MillerLoopResultScalingFactor<u384>, mut s0: felt252, mut s1: felt252, mut s2: felt252
) -> (felt252, felt252, felt252) {
let base: felt252 = 79228162514264337593543950336; // 2**96

Expand All @@ -203,6 +203,32 @@ pub fn hash_MillerLoopResultScalingFactor(
return (_s0, _s1, _s2);
}

pub fn hash_MillerLoopResultScalingFactor_u288(
elmt: MillerLoopResultScalingFactor<u288>, mut s0: felt252, mut s1: felt252, mut s2: felt252
) -> (felt252, felt252, felt252) {
let base: felt252 = 79228162514264337593543950336; // 2**96

let in_1 = s0 + elmt.w0.limb0.into() + base * elmt.w0.limb1.into();
let in_2 = s1 + elmt.w0.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, s2);
let in_1 = _s0 + elmt.w2.limb0.into() + base * elmt.w2.limb1.into();
let in_2 = _s1 + elmt.w2.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w4.limb0.into() + base * elmt.w4.limb1.into();
let in_2 = _s1 + elmt.w4.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w6.limb0.into() + base * elmt.w6.limb1.into();
let in_2 = _s1 + elmt.w6.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w8.limb0.into() + base * elmt.w8.limb1.into();
let in_2 = _s1 + elmt.w8.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
let in_1 = _s0 + elmt.w10.limb0.into() + base * elmt.w10.limb1.into();
let in_2 = _s1 + elmt.w10.limb2.into();
let (_s0, _s1, _s2) = hades_permutation(in_1, in_2, _s2);
return (_s0, _s1, _s2);
}

// Apply sponge construction to a sequence of E12D elements from an initial state (s0, s1, s2)
pub fn hash_E12D_u384_transcript(
transcript: Span<E12D<u384>>, mut s0: felt252, mut s1: felt252, mut s2: felt252
Expand Down