From da5d143eee7da9595e697a6a44b536a6b9583c31 Mon Sep 17 00:00:00 2001 From: feltroidprime Date: Tue, 30 Jul 2024 16:35:01 +0200 Subject: [PATCH] test multi miller loop with precomputed lines --- hydra/precompiled_circuits/all_circuits.py | 4 +- .../fixed_G2_groth16_circuits.py | 1311 +++++++++++++++++ .../precompiled_circuits/multi_miller_loop.py | 358 ++++- .../multi_pairing_check.py | 15 +- tests/benchmarks.py | 2 +- .../hydra/circuits/test_multi_miller_loop.py | 115 ++ tools/starknet_cli.py | 4 +- 7 files changed, 1735 insertions(+), 74 deletions(-) create mode 100644 hydra/precompiled_circuits/fixed_G2_groth16_circuits.py create mode 100644 tests/hydra/circuits/test_multi_miller_loop.py diff --git a/hydra/precompiled_circuits/all_circuits.py b/hydra/precompiled_circuits/all_circuits.py index db00c373..a9bc4258 100644 --- a/hydra/precompiled_circuits/all_circuits.py +++ b/hydra/precompiled_circuits/all_circuits.py @@ -1014,7 +1014,7 @@ def _run_circuit_inner(self, input: list[PyFelt]): n_pairs=n_pairs, hash_input=True, ) - circuit.write_p_and_q(input) + circuit.write_p_and_q_raw(input) m = circuit.miller_loop(n_pairs) @@ -1057,7 +1057,7 @@ def _run_circuit_inner(self, input: list[PyFelt]): n_pairs=n_pairs, hash_input=True, ) - circuit.write_p_and_q(input) + circuit.write_p_and_q_raw(input) m, _, _, _, _ = circuit.multi_pairing_check(n_pairs) diff --git a/hydra/precompiled_circuits/fixed_G2_groth16_circuits.py b/hydra/precompiled_circuits/fixed_G2_groth16_circuits.py new file mode 100644 index 00000000..537d602c --- /dev/null +++ b/hydra/precompiled_circuits/fixed_G2_groth16_circuits.py @@ -0,0 +1,1311 @@ +from hydra.precompiled_circuits.all_circuits import ( + BaseEXTFCircuit, + PyFelt, + G1Point, + CurveID, + randint, + STARK, + multi_pairing_check, + u384, + G2Point, + ModuloCircuitElement, + G2PointCircuit, + E12D, + BN254_ID, + ModuloCircuit, + G1PointCircuit, + BNProcessedPair, + BLSProcessedPair, + BLS12_381_ID, + MillerLoopResultScalingFactor, + structs, + get_irreducible_poly, + ExtensionFieldModuloCircuit, +) + + +class Groth16Bit0Loop(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + self.n_pairs = 3 + super().__init__( + name=f"groth16_bit0", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + self.generic_over_curve = True + + def build_input(self) -> list[PyFelt]: + input = [] + for _ in range(self.n_pairs - 1): + p = G1Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(p.y).__inv__() + xNegOverY = -self.field(p.x) * yInv + r0a0 = self.field.random() + r0a1 = self.field.random() + r1a0 = self.field.random() + r1a1 = self.field.random() + input.extend([yInv, xNegOverY, r0a0, r0a1, r1a0, r1a1]) + + p = G1Point.gen_random_point(CurveID(self.curve_id)) + current_q2 = G2Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(p.y).__inv__() + xNegOverY = -self.field(p.x) * yInv + input.extend( + [ + yInv, + xNegOverY, + self.field(current_q2.x[0]), + self.field(current_q2.x[1]), + self.field(current_q2.y[0]), + self.field(current_q2.y[1]), + ] + ) + input.append( + self.field.random() + ) # LHS accumulation = Sum(ci*(prod(Pi,j)-Ri)(z)) + input.append(self.field.random()) # R(i-1)(z) = f(i-1)(z) for miller loop. + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + input.append(self.field.random()) # ci_minus_one + input.append(self.field(randint(0, STARK - 1))) # z + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + assert n_pairs >= 2, f"n_pairs must be >= 2, got {n_pairs}" + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + # Parse (yInv, xNegOverY, Qx0, Qx1, Qy0, Qy1) * n_pairs + current_points = [] + for i in range(n_pairs): + circuit.yInv.append( + circuit.write_struct(u384(name=f"yInv_{i}", elmts=[input.pop(0)])) + ) + circuit.xNegOverY.append( + circuit.write_struct(u384(name=f"xNegOverY_{i}", elmts=[input.pop(0)])) + ) + current_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q{i}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + current_points.append( + ( + [ + current_pt[0], + current_pt[1], + ], + [ + current_pt[2], + current_pt[3], + ], + ) + ) + + lhs_i = circuit.write_struct(u384(name="lhs_i", elmts=[input.pop(0)])) + f_i_of_z: ModuloCircuitElement = circuit.write_struct( + u384(name="f_i_of_z", elmts=[input.pop(0)]) + ) + + # f_i_plus_one is R as the result E12 for this bit. + f_i_plus_one: list[ModuloCircuitElement] = circuit.write_struct( + E12D(name="f_i_plus_one", elmts=[input.pop(0) for _ in range(12)]) + ) + ci = circuit.write_struct(u384(name="ci", elmts=[input.pop(0)])) + assert len(input) == 1, f"Input should be empty now" + + circuit.create_powers_of_Z( + circuit.write_struct(u384(name="z", elmts=[input.pop(0)])), max_degree=11 + ) + ci_plus_one = circuit.mul(ci, ci, f"Compute c_i = (c_(i-1))^2") + + assert len(input) == 0, f"Input should be empty now" + assert len(current_points) == n_pairs + + sum_i_prod_k_P = circuit.mul( + f_i_of_z, f_i_of_z, f"Square f evaluation in Z, the result of previous bit." + ) + new_points = [] + for k in range(n_pairs): + T, l1 = circuit.double_step(current_points[k], k) + sum_i_prod_k_P = circuit.mul( + sum_i_prod_k_P, + circuit.eval_poly_in_precomputed_Z(l1, circuit.line_sparsity), + f"Mul (f(z)^2 * Π_0_k-1(line_k(z))) * line_{k}(z)", + ) + + new_points.append(T) + + f_i_plus_one_of_z = circuit.eval_poly_in_precomputed_Z(f_i_plus_one) + new_lhs = circuit.mul( + ci_plus_one, + circuit.sub(sum_i_prod_k_P, f_i_plus_one_of_z, f"(Π(i,k) (Pk(z))) - Ri(z)"), + f"ci * ((Π(i,k) (Pk(z)) - Ri(z))", + ) + lhs_i_plus_one = circuit.add( + lhs_i, new_lhs, f"LHS = LHS + ci * ((Π(i,k) (Pk(z)) - Ri(z))" + ) + for i, point in enumerate(new_points): + # circuit.extend_output(point[0]) + # circuit.extend_output(point[1]) + circuit.extend_struct_output( + G2PointCircuit( + name=f"Q{i}", + elmts=[point[0][0], point[0][1], point[1][0], point[1][1]], + ) + ) + circuit.extend_struct_output( + u384(name="f_i_plus_one_of_z", elmts=[f_i_plus_one_of_z]) + ) + circuit.extend_struct_output( + u384(name="lhs_i_plus_one", elmts=[lhs_i_plus_one]) + ) + circuit.extend_struct_output(u384(name="ci_plus_one", elmts=[ci_plus_one])) + + return circuit + + +class MPCheckBit00Loop(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + n_pairs: int = 0, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + self.n_pairs = n_pairs + super().__init__( + name=f"mpcheck_bit00", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + self.generic_over_curve = True + + def build_input(self) -> list[PyFelt]: + input = [] + for _ in range(self.n_pairs): + p = G1Point.gen_random_point(CurveID(self.curve_id)) + current_q = G2Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(p.y).__inv__() + xNegOverY = -self.field(p.x) * yInv + input.extend( + [ + yInv, + xNegOverY, + self.field(current_q.x[0]), + self.field(current_q.x[1]), + self.field(current_q.y[0]), + self.field(current_q.y[1]), + ] + ) + + input.append( + self.field.random() + ) # LHS accumulation = Sum(ci*(prod(Pi,j)-Ri)(z)) + input.append(self.field.random()) # R(i-1)(z) = f(i-1)(z) for miller loop. + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + input.append(self.field.random()) # ci_minus_one + input.append(self.field(randint(0, STARK - 1))) # z + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + assert n_pairs >= 2, f"n_pairs must be >= 2, got {n_pairs}" + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + # Parse (yInv, xNegOverY, Qx0, Qx1, Qy0, Qy1) * n_pairs + current_points = [] + for i in range(n_pairs): + circuit.yInv.append( + circuit.write_struct(u384(name=f"yInv_{i}", elmts=[input.pop(0)])) + ) + circuit.xNegOverY.append( + circuit.write_struct(u384(name=f"xNegOverY_{i}", elmts=[input.pop(0)])) + ) + current_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q{i}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + current_points.append( + ( + [ + current_pt[0], + current_pt[1], + ], + [ + current_pt[2], + current_pt[3], + ], + ) + ) + + lhs_i = circuit.write_struct(u384(name="lhs_i", elmts=[input.pop(0)])) + f_i_of_z: ModuloCircuitElement = circuit.write_struct( + u384(name="f_i_of_z", elmts=[input.pop(0)]) + ) + + # f_i_plus_one is R as the result E12 for this bit. + f_i_plus_one: list[ModuloCircuitElement] = circuit.write_struct( + E12D(name="f_i_plus_one", elmts=[input.pop(0) for _ in range(12)]) + ) + ci = circuit.write_struct(u384(name="ci", elmts=[input.pop(0)])) + assert len(input) == 1, f"Input should be empty now" + + circuit.create_powers_of_Z( + circuit.write_struct(u384(name="z", elmts=[input.pop(0)])), max_degree=11 + ) + ci_plus_one = circuit.mul(ci, ci, f"Compute c_i = (c_(i-1))^2") + + assert len(input) == 0, f"Input should be empty now" + assert len(current_points) == n_pairs + + sum_i_prod_k_P = circuit.mul( + f_i_of_z, f_i_of_z, f"Square f evaluation in Z, the result of previous bit." + ) + new_points = [] + for k in range(n_pairs): + T, l1 = circuit.double_step(current_points[k], k) + sum_i_prod_k_P = circuit.mul( + sum_i_prod_k_P, + circuit.eval_poly_in_precomputed_Z( + l1, circuit.line_sparsity, f"line_{k}" + ), + f"Mul (f(z)^2 * Π_0_k-1(line_k(z))) * line_i_{k}(z)", + ) + + new_points.append(T) + + sum_i_prod_k_P = circuit.mul( + sum_i_prod_k_P, + sum_i_prod_k_P, + "Compute (f^2 * Π(i,k) (line_i,k(z))) ^ 2 = f^4 * (Π(i,k) (line_i,k(z)))^2", + ) + + new_new_points = [] + for k in range(n_pairs): + T, l1 = circuit.double_step(new_points[k], k) + sum_i_prod_k_P = circuit.mul( + sum_i_prod_k_P, + circuit.eval_poly_in_precomputed_Z( + l1, circuit.line_sparsity, f"line_{k}" + ), + f"Mul (f^4 * (Π(i,k) (line_i,k(z)))^2) * line_i+1_{k}(z)", + ) + + new_new_points.append(T) + + f_i_plus_one_of_z = circuit.eval_poly_in_precomputed_Z( + f_i_plus_one, poly_name="R" + ) + new_lhs = circuit.mul( + ci_plus_one, + circuit.sub(sum_i_prod_k_P, f_i_plus_one_of_z, f"(Π(i,k) (Pk(z))) - Ri(z)"), + f"ci * ((Π(i,k) (Pk(z)) - Ri(z))", + ) + lhs_i_plus_one = circuit.add( + lhs_i, new_lhs, f"LHS = LHS + ci * ((Π(i,k) (Pk(z)) - Ri(z))" + ) + for i, point in enumerate(new_new_points): + # circuit.extend_output(point[0]) + # circuit.extend_output(point[1]) + circuit.extend_struct_output( + G2PointCircuit( + name=f"Q{i}", + elmts=[point[0][0], point[0][1], point[1][0], point[1][1]], + ) + ) + circuit.extend_struct_output( + u384(name="f_i_plus_one_of_z", elmts=[f_i_plus_one_of_z]) + ) + circuit.extend_struct_output( + u384(name="lhs_i_plus_one", elmts=[lhs_i_plus_one]) + ) + circuit.extend_struct_output(u384(name="ci_plus_one", elmts=[ci_plus_one])) + + return circuit + + +class MPCheckBit1Loop(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + n_pairs: int = 0, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + self.n_pairs = n_pairs + super().__init__( + name="mpcheck_bit1", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + self.generic_over_curve = True + + def build_input(self) -> list[PyFelt]: + """ + - (PyInv, PxNegOverY, Qx0, Qx1, Qy0, Qy1) * (n_pairs) + - LHS = sum(ci*prod(Pi,j))(z) (1 u384) + - sum (ci_-1*Ri_-1) (12 u384) + - f:E12 (12 u384) + - z (1 u384) + - ci (1 u384) + """ + input = [] + for _ in range(self.n_pairs): + p = G1Point.gen_random_point(CurveID(self.curve_id)) + current_q = G2Point.gen_random_point(CurveID(self.curve_id)) + q_or_q_neg = G2Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(p.y).__inv__() + xNegOverY = -self.field(p.x) * yInv + input.extend( + [ + yInv, + xNegOverY, + self.field(current_q.x[0]), + self.field(current_q.x[1]), + self.field(current_q.y[0]), + self.field(current_q.y[1]), + self.field(q_or_q_neg.x[0]), + self.field(q_or_q_neg.x[1]), + self.field(q_or_q_neg.y[0]), + self.field(q_or_q_neg.y[1]), + ] + ) + + input.append( + self.field.random() + ) # LHS accumulation = Sum(ci*(prod(Pi,j)-Ri)(z)) + input.append(self.field.random()) # R(i-1)(z) = f(i-1)(z) for miller loop. + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + input.append(self.field.random()) # c_or_cinv_of_z + input.append(self.field(randint(0, STARK - 1))) # z + input.append(self.field.random()) # ci + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + assert n_pairs >= 2, f"n_pairs must be >= 2, got {n_pairs}" + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + current_points = [] + q_or_q_neg_points = [] + for i in range(n_pairs): + circuit.yInv.append( + circuit.write_struct(u384(name=f"yInv_{i}", elmts=[input.pop(0)])) + ) + circuit.xNegOverY.append( + circuit.write_struct(u384(name=f"xNegOverY_{i}", elmts=[input.pop(0)])) + ) + curr_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q{i}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + current_points.append( + ( + [ + curr_pt[0], + curr_pt[1], + ], + [ + curr_pt[2], + curr_pt[3], + ], + ) + ) + q_or_q_neg_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q_or_Qneg_{i}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + q_or_q_neg_points.append( + ( + [ + q_or_q_neg_pt[0], + q_or_q_neg_pt[1], + ], + [ + q_or_q_neg_pt[2], + q_or_q_neg_pt[3], + ], + ) + ) + + lhs_i = circuit.write_struct(u384(name="lhs_i", elmts=[input.pop(0)])) + f_i_of_z: ModuloCircuitElement = circuit.write_struct( + u384(name="f_i_of_z", elmts=[input.pop(0)]) + ) + + f_i_plus_one: list[ModuloCircuitElement] = circuit.write_struct( + E12D(name="f_i_plus_one", elmts=[input.pop(0) for _ in range(12)]) + ) + c_or_cinv_of_z: ModuloCircuitElement = circuit.write_struct( + u384(name="c_or_cinv_of_z", elmts=[input.pop(0)]) + ) + circuit.create_powers_of_Z( + circuit.write_struct(u384(name="z", elmts=[input.pop(0)])), max_degree=11 + ) + assert len(input) == 1, f"Input should be empty now" + + ci = circuit.write_struct(u384(name="ci", elmts=[input.pop(0)])) + ci_plus_one = circuit.mul(ci, ci) + + assert len(input) == 0, f"Input should be empty now" + assert len(current_points) == n_pairs + + sum_i_prod_k_P_of_z = circuit.mul( + f_i_of_z, f_i_of_z + ) # Square f evaluation in Z, the result of previous bit. + new_points = [] + + for k in range(n_pairs): + T, l1, l2 = circuit.double_and_add_step( + current_points[k], q_or_q_neg_points[k], k + ) + sum_i_prod_k_P_of_z = circuit.mul( + sum_i_prod_k_P_of_z, + circuit.eval_poly_in_precomputed_Z( + l1, circuit.line_sparsity, f"line_{k}p_1" + ), + ) + sum_i_prod_k_P_of_z = circuit.mul( + sum_i_prod_k_P_of_z, + circuit.eval_poly_in_precomputed_Z( + l2, circuit.line_sparsity, f"line_{k}p_2" + ), + ) + new_points.append(T) + + sum_i_prod_k_P_of_z = circuit.mul(sum_i_prod_k_P_of_z, c_or_cinv_of_z) + + f_i_plus_one_of_z = circuit.eval_poly_in_precomputed_Z( + f_i_plus_one, poly_name="R" + ) + new_lhs = circuit.mul( + ci_plus_one, + circuit.sub(sum_i_prod_k_P_of_z, f_i_plus_one_of_z), + comment=f"ci * ((Π(i,k) (Pk(z)) - Ri(z))", + ) + lhs_i_plus_one = circuit.add(lhs_i, new_lhs) + + for i, point in enumerate(new_points): + circuit.extend_struct_output( + G2PointCircuit( + name=f"Q{i}", + elmts=[point[0][0], point[0][1], point[1][0], point[1][1]], + ) + ) + + circuit.extend_struct_output( + u384(name="f_i_plus_one_of_z", elmts=[f_i_plus_one_of_z]) + ) + circuit.extend_struct_output( + u384(name="lhs_i_plus_one", elmts=[lhs_i_plus_one]) + ) + circuit.extend_struct_output(u384(name="ci_plus_one", elmts=[ci_plus_one])) + + return circuit + + +class MPCheckPreparePairs(BaseEXTFCircuit): + """ + This circuit is used to prepare points for the multi-pairing check. + For BN curve, it will compute yInv and xNegOverY for each point + negate the y of the G2 point. + For BLS curve, it will only compute yInv and xNegOverY for each point. + """ + + def __init__( + self, + curve_id: int, + n_pairs: int = 0, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + self.n_pairs = n_pairs + super().__init__( + name="mpcheck_prepare_points", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + self.generic_over_curve = True + + def build_input(self) -> list[PyFelt]: + """ + - ((Px, Py) + (Qy0, Qy1) if BN curve) * (n_pairs) + """ + input = [] + for _ in range(self.n_pairs): + p = G1Point.gen_random_point(CurveID(self.curve_id)) + q = G2Point.gen_random_point(CurveID(self.curve_id)) + input.extend( + [ + self.field(p.x), + self.field(p.y), + ] + ) + if self.curve_id == BN254_ID: + input.extend( + [ + self.field(q.y[0]), + self.field(q.y[1]), + ] + ) + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + circuit: ModuloCircuit = ModuloCircuit( + self.name, + self.curve_id, + compilation_mode=self.compilation_mode, + ) + for i in range(n_pairs): + x, y = circuit.write_struct( + G1PointCircuit(name=f"p_{i}", elmts=[input.pop(0), input.pop(0)]) + ) + yInv = circuit.inv(y) + xNegOverY = circuit.neg(circuit.mul(x, yInv)) + + if self.curve_id == BN254_ID: + Qy0 = circuit.write_struct(u384(name=f"Qy0_{i}", elmts=[input.pop(0)])) + Qy1 = circuit.write_struct(u384(name=f"Qy1_{i}", elmts=[input.pop(0)])) + Qyneg0 = circuit.neg(Qy0) + Qyneg1 = circuit.neg(Qy1) + circuit.extend_struct_output( + BNProcessedPair( + name=f"p_{i}", elmts=[yInv, xNegOverY, Qyneg0, Qyneg1] + ) + ) + else: + circuit.extend_struct_output( + BLSProcessedPair(name=f"p_{i}", elmts=[yInv, xNegOverY]) + ) + + return circuit + + +class MPCheckPrepareLambdaRootEvaluations(BaseEXTFCircuit): + def __init__(self, curve_id: int, auto_run: bool = True, compilation_mode: int = 1): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + super().__init__( + name="mpcheck_lambda_root_eval", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + + def build_input(self) -> list[PyFelt]: + input = [] + input.extend([self.field.random() for _ in range(12)]) # lambda_root + input.append(self.field.random()) # z + input.extend([self.field.random() for _ in range(6)]) # w - scaling factor + if self.curve_id == BN254_ID: + input.extend([self.field.random() for _ in range(12)]) # c_inv + input.append(self.field.random()) # c_0 + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + circuit = multi_pairing_check.MultiPairingCheckCircuit( + name=self.name, + curve_id=self.curve_id, + n_pairs=2, # Unused + hash_input=False, + compilation_mode=self.compilation_mode, + ) + c_or_c_inv = circuit.write_struct( + E12D( + name=( + f"lambda_root_inverse" + if self.curve_id == BLS12_381_ID + else "lambda_root" + ), + elmts=[input.pop(0) for _ in range(12)], + ) + ) + z = circuit.write_struct(u384(name="z", elmts=[input.pop(0)])) + circuit.create_powers_of_Z(z, max_degree=11) + + if self.curve_id == BLS12_381_ID: + # Conjugate c_inverse for BLS. + c_or_c_inv = circuit.conjugate_e12d(c_or_c_inv) + + c_or_c_inv_of_z = circuit.eval_poly_in_precomputed_Z( + c_or_c_inv, poly_name=f"C_inv" if self.curve_id == BLS12_381_ID else "C" + ) + circuit.extend_struct_output( + u384( + name="c_inv_of_z" if self.curve_id == BLS12_381_ID else "c_of_z", + elmts=[c_or_c_inv_of_z], + ) + ) + + scaling_factor_sparsity = [ + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + ] # E6 subfield within E12 : to_direct(E6.random + w * E6.zero) + + scaling_factor_compressed = circuit.write_struct( + MillerLoopResultScalingFactor( + name="scaling_factor", elmts=[input.pop(0) for _ in range(6)] + ) + ) + scaling_factor = [ + scaling_factor_compressed[0], + None, + scaling_factor_compressed[1], + None, + scaling_factor_compressed[2], + None, + scaling_factor_compressed[3], + None, + scaling_factor_compressed[4], + None, + scaling_factor_compressed[5], + None, + ] + + scaling_factor_of_z = circuit.eval_poly_in_precomputed_Z( + scaling_factor, sparsity=scaling_factor_sparsity, poly_name="W" + ) + + circuit.extend_struct_output( + u384("scaling_factor_of_z", elmts=[scaling_factor_of_z]) + ) + + if self.curve_id == BLS12_381_ID: + c_inv = c_or_c_inv + # Compute needed frobenius: + c_inv_frob_1 = circuit.frobenius(c_inv, 1) + c_inv_frob_1_of_z = circuit.eval_poly_in_precomputed_Z( + c_inv_frob_1, poly_name="C_inv_frob_1" + ) + circuit.extend_struct_output( + u384("c_inv_frob_1_of_z", elmts=[c_inv_frob_1_of_z]) + ) + elif self.curve_id == BN254_ID: + c = c_or_c_inv + c_of_z = c_or_c_inv_of_z + c_inv = circuit.write_struct( + E12D(name="c_inv", elmts=[input.pop(0) for _ in range(12)]) + ) + c_0 = circuit.write_struct(u384(name="c_0", elmts=[input.pop(0)])) + c_inv_of_z = circuit.eval_poly_in_precomputed_Z(c_inv, poly_name=f"C_inv") + circuit.extend_struct_output(u384(name="c_inv_of_z", elmts=[c_inv_of_z])) + lhs = circuit.sub( + circuit.mul(c_of_z, c_inv_of_z), + circuit.set_or_get_constant(1), + comment="c_of_z * c_inv_of_z - 1", + ) + lhs = circuit.mul(lhs, c_0, comment="c_0 * (c_of_z * c_inv_of_z - 1)") + circuit.extend_struct_output(u384("lhs", elmts=[lhs])) + + # Compute needed frobenius: + c_inv_frob_1 = circuit.frobenius(c_inv, 1) + c_frob_2 = circuit.frobenius(c, 2) + c_inv_frob_3 = circuit.frobenius(c_inv, 3) + + c_inv_frob_1_of_z = circuit.eval_poly_in_precomputed_Z( + c_inv_frob_1, poly_name="C_inv_frob_1" + ) + c_frob_2_of_z = circuit.eval_poly_in_precomputed_Z( + c_frob_2, poly_name="C_frob_2" + ) + c_inv_frob_3_of_z = circuit.eval_poly_in_precomputed_Z( + c_inv_frob_3, poly_name="C_inv_frob_3" + ) + + circuit.extend_struct_output( + u384("c_inv_frob_1_of_z", elmts=[c_inv_frob_1_of_z]) + ) + circuit.extend_struct_output(u384("c_frob_2_of_z", elmts=[c_frob_2_of_z])) + circuit.extend_struct_output( + u384("c_inv_frob_3_of_z", elmts=[c_inv_frob_3_of_z]) + ) + + return circuit + + +class MPCheckInitBit(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + n_pairs: int, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + assert 2 <= n_pairs <= 3, f"n_pairs must be between 2 and 3, got {n_pairs}" + self.n_pairs = n_pairs + super().__init__( + name="mpcheck_init_bit", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + + def build_input(self) -> list[PyFelt]: + input = [] + for _ in range(self.n_pairs): + p = G1Point.gen_random_point(CurveID(self.curve_id)) + current_q = G2Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(p.y).__inv__() + xNegOverY = -self.field(p.x) * yInv + input.extend( + [ + yInv, + xNegOverY, + self.field(current_q.x[0]), + self.field(current_q.x[1]), + self.field(current_q.y[0]), + self.field(current_q.y[1]), + ] + ) + + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + input.append(self.field.random()) # c0 + input.append(self.field.random()) # z + input.append( + self.field.random() + ) # c_inv_of_z (BN) or conjugate(c_inv)_of_z (BLS) + + if self.curve_id == BN254_ID: + input.append(self.field.random()) # previous_lhs + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + current_points = [] + for i in range(n_pairs): + circuit.yInv.append( + circuit.write_struct(u384(name=f"yInv_{i}", elmts=[input.pop(0)])) + ) + circuit.xNegOverY.append( + circuit.write_struct(u384(name=f"xNegOverY_{i}", elmts=[input.pop(0)])) + ) + curr_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q{i}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + current_points.append( + ( + [ + curr_pt[0], + curr_pt[1], + ], + [ + curr_pt[2], + curr_pt[3], + ], + ) + ) + + R_i = circuit.write_struct(E12D("R_i", elmts=[input.pop(0) for _ in range(12)])) + c0 = circuit.write_struct(u384("c0", elmts=[input.pop(0)])) + z = circuit.write_struct(u384("z", elmts=[input.pop(0)])) + c_inv_of_z = circuit.write_struct(u384("c_inv_of_z", elmts=[input.pop(0)])) + circuit.create_powers_of_Z(z, max_degree=11) + + f_i_plus_one_of_z = circuit.eval_poly_in_precomputed_Z(R_i, poly_name="R") + sum_i_prod_k_P_of_z = circuit.mul( + c_inv_of_z, c_inv_of_z + ) # # At initialisation, f=1/c so f^2 = 1/c^2 + new_points = [] + if self.curve_id == BN254_ID: + c_i = circuit.mul( + c0, c0 + ) # Second relation for BN at init bit, need to update c_i. + for k in range(n_pairs): + T, l1 = circuit.double_step(current_points[k], k) + sum_i_prod_k_P_of_z = circuit.mul( + sum_i_prod_k_P_of_z, + circuit.eval_poly_in_precomputed_Z( + l1, circuit.line_sparsity, f"line_{k}p_1" + ), + ) + new_points.append(T) + elif self.curve_id == BLS12_381_ID: + c_i = c0 # first relation for BLS at init bit, no need to update c_i. + # bit +1: multiply f^2 by 1/c + sum_i_prod_k_P_of_z = circuit.mul(c_inv_of_z, sum_i_prod_k_P_of_z) + for k in range(n_pairs): + T, l1, l2 = circuit.triple_step(current_points[k], k) + sum_i_prod_k_P_of_z = circuit.mul( + sum_i_prod_k_P_of_z, + circuit.eval_poly_in_precomputed_Z( + l1, circuit.line_sparsity, f"line_{k}p_1" + ), + ) + sum_i_prod_k_P_of_z = circuit.mul( + sum_i_prod_k_P_of_z, + circuit.eval_poly_in_precomputed_Z( + l2, circuit.line_sparsity, f"line_{k}p_2" + ), + ) + new_points.append(T) + + new_lhs = circuit.mul( + c_i, + circuit.sub(sum_i_prod_k_P_of_z, f_i_plus_one_of_z), + comment=f"ci * ((Π(i,k) (Pk(z)) - Ri(z))", + ) + + if self.curve_id == BLS12_381_ID: + new_lhs = new_lhs + elif self.curve_id == BN254_ID: + previous_lhs = circuit.write_struct( + u384("previous_lhs", elmts=[input.pop(0)]) + ) + new_lhs = circuit.add(new_lhs, previous_lhs) + + # OUTPUT + for i, point in enumerate(new_points): + circuit.extend_struct_output( + G2PointCircuit( + name=f"Q{i}", + elmts=[point[0][0], point[0][1], point[1][0], point[1][1]], + ) + ) + circuit.extend_struct_output(u384("new_lhs", elmts=[new_lhs])) + if self.curve_id == BN254_ID: + circuit.extend_struct_output(u384("c_i", elmts=[c_i])) + circuit.extend_struct_output( + u384("f_i_plus_one_of_z", elmts=[f_i_plus_one_of_z]) + ) + return circuit + + +class MPCheckFinalizeBN(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + n_pairs: int, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + assert 2 <= n_pairs <= 3, f"n_pairs must be between 2 and 3, got {n_pairs}" + self.n_pairs = n_pairs + self.max_q_degree = multi_pairing_check.get_max_Q_degree(curve_id, self.n_pairs) + + super().__init__( + name="mpcheck_finalize_bn", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + + def build_input(self) -> list[PyFelt]: + if self.curve_id == BLS12_381_ID: + return [] + input = [] + for _ in range(self.n_pairs): + original_p = G1Point.gen_random_point(CurveID(self.curve_id)) + original_q = G2Point.gen_random_point(CurveID(self.curve_id)) + yInv = self.field(original_p.y).__inv__() + xNegOverY = -self.field(original_p.x) * yInv + current_q = G2Point.gen_random_point(CurveID(self.curve_id)) + + input.extend( + [ + self.field(original_q.x[0]), + self.field(original_q.x[1]), + self.field(original_q.y[0]), + self.field(original_q.y[1]), + yInv, + xNegOverY, + self.field(current_q.x[0]), + self.field(current_q.x[1]), + self.field(current_q.y[0]), + self.field(current_q.y[1]), + ] + ) + + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + input.extend([self.field.random() for _ in range(12)]) # Ri = new_f + + input.append(self.field.random()) # c_i-1 + input.append(self.field.random()) # z + input.append(self.field.random()) # w_of_z + input.append(self.field.random()) # c_inv_frob_1_of_z + input.append(self.field.random()) # c_frob_2_of_z + input.append(self.field.random()) # c_inv_frob_3_of_z + input.append(self.field.random()) # previous_lhs + input.append(self.field.random()) # R_n_minus_3_of_z + + input.extend([self.field.random() for _ in range(self.max_q_degree + 1)]) + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + if self.curve_id == BLS12_381_ID: + return circuit + + current_points = [] + for k in range(n_pairs): + original_Q = circuit.write_struct( + G2PointCircuit( + name=f"original_Q{k}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + circuit.Q.append( + ([original_Q[0], original_Q[1]], [original_Q[2], original_Q[3]]) + ) + yInv_k = circuit.write_struct(u384(f"yInv_{k}", elmts=[input.pop(0)])) + xNegOverY_k = circuit.write_struct( + u384(f"xNegOverY_{k}", elmts=[input.pop(0)]) + ) + circuit.yInv.append(yInv_k) + circuit.xNegOverY.append(xNegOverY_k) + curr_pt = circuit.write_struct( + G2PointCircuit( + name=f"Q{k}", + elmts=[input.pop(0), input.pop(0), input.pop(0), input.pop(0)], + ) + ) + current_points.append( + ( + [ + curr_pt[0], + curr_pt[1], + ], + [ + curr_pt[2], + curr_pt[3], + ], + ) + ) + + R_n_minus_2 = circuit.write_struct( + E12D("R_n_minus_2", elmts=[input.pop(0) for _ in range(12)]) + ) + R_n_minus_1 = circuit.write_struct( + E12D("R_n_minus_1", elmts=[input.pop(0) for _ in range(12)]) + ) + c_n_minus_3 = circuit.write_struct(u384("c_n_minus_3", elmts=[input.pop(0)])) + w_of_z = circuit.write_struct(u384("w_of_z", elmts=[input.pop(0)])) + z = circuit.write_struct(u384("z", elmts=[input.pop(0)])) + c_inv_frob_1_of_z = circuit.write_struct( + u384("c_inv_frob_1_of_z", elmts=[input.pop(0)]) + ) + c_frob_2_of_z = circuit.write_struct( + u384("c_frob_2_of_z", elmts=[input.pop(0)]) + ) + c_inv_frob_3_of_z = circuit.write_struct( + u384("c_inv_frob_3_of_z", elmts=[input.pop(0)]) + ) + previous_lhs = circuit.write_struct(u384("previous_lhs", elmts=[input.pop(0)])) + R_n_minus_3_of_z = circuit.write_struct( + u384("R_n_minus_3_of_z", elmts=[input.pop(0)]) + ) + Q = circuit.write_struct( + structs.u384Array( + "Q", elmts=[input.pop(0) for _ in range(self.max_q_degree + 1)] + ) + ) + + circuit.create_powers_of_Z(z, max_degree=self.max_q_degree) + + c_n_minus_2 = circuit.mul(c_n_minus_3, c_n_minus_3) + c_n_minus_1 = circuit.mul(c_n_minus_2, c_n_minus_2) + + R_n_minus_2_of_z = circuit.eval_poly_in_precomputed_Z( + R_n_minus_2, poly_name="R_n_minus_2" + ) + R_n_minus_1_of_z = circuit.eval_poly_in_precomputed_Z( + R_n_minus_1, poly_name="R_n_minus_1" + ) + + # Relation n-2 : f * lines + prod_k_P_of_z_n_minus_2 = R_n_minus_3_of_z # Init + lines = circuit.bn254_finalize_step(current_points) + for l in lines: + prod_k_P_of_z_n_minus_2 = circuit.mul( + prod_k_P_of_z_n_minus_2, + circuit.eval_poly_in_precomputed_Z( + l, circuit.line_sparsity, f"line_{k}" + ), + ) + + lhs_n_minus_2 = circuit.mul( + c_n_minus_2, + circuit.sub(prod_k_P_of_z_n_minus_2, R_n_minus_2_of_z), + comment=f"c_n_minus_2 * ((Π(n-2,k) (Pk(z)) - R_n_minus_2(z))", + ) + + # Relation n-1 (last one) : f * w * c_inv_frob_1 * c_frob_2 * c_inv_frob_3 + prod_k_P_of_z_n_minus_1 = circuit.mul(R_n_minus_2_of_z, c_inv_frob_1_of_z) + prod_k_P_of_z_n_minus_1 = circuit.mul(prod_k_P_of_z_n_minus_1, c_frob_2_of_z) + prod_k_P_of_z_n_minus_1 = circuit.mul( + prod_k_P_of_z_n_minus_1, c_inv_frob_3_of_z + ) + prod_k_P_of_z_n_minus_1 = circuit.mul(prod_k_P_of_z_n_minus_1, w_of_z) + + lhs_n_minus_1 = circuit.mul( + c_n_minus_1, + circuit.sub(prod_k_P_of_z_n_minus_1, R_n_minus_1_of_z), + comment=f"c_n_minus_1 * ((Π(n-1,k) (Pk(z)) - R_n_minus_1(z))", + ) + + _final_lhs = circuit.add(previous_lhs, lhs_n_minus_2) + final_lhs = circuit.add(_final_lhs, lhs_n_minus_1) + + Q_of_z = circuit.eval_poly_in_precomputed_Z(Q, poly_name="big_Q") + P_irr, P_irr_sparsity = circuit.write_sparse_constant_elements( + get_irreducible_poly(self.curve_id, 12).get_coeffs(), + ) + P_of_z = circuit.eval_poly_in_precomputed_Z( + P_irr, P_irr_sparsity, poly_name="P_irr" + ) + check = circuit.sub(final_lhs, circuit.mul(Q_of_z, P_of_z)) + + circuit.extend_struct_output(u384("final_check", elmts=[check])) + return circuit + + +class MPCheckFinalizeBLS(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + n_pairs: int, + # include_m: bool, + auto_run: bool = True, + compilation_mode: int = 1, + ): + assert compilation_mode == 1, "Compilation mode 1 is required for this circuit" + assert 2 <= n_pairs <= 3, f"n_pairs must be between 2 and 3, got {n_pairs}" + self.n_pairs = n_pairs + self.max_q_degree = multi_pairing_check.get_max_Q_degree(curve_id, self.n_pairs) + + super().__init__( + name="mpcheck_finalize_bls", + input_len=None, + curve_id=curve_id, + auto_run=auto_run, + compilation_mode=compilation_mode, + ) + + def build_input(self) -> list[PyFelt]: + if self.curve_id == BN254_ID: + return [] + input = [] + + input.extend([self.field.random() for _ in range(12)]) # R_n_minus_1 + + input.append(self.field.random()) # c_i-1 + input.append(self.field.random()) # z + input.append(self.field.random()) # w_of_z + input.append(self.field.random()) # c_inv_frob_1_of_z + input.append(self.field.random()) # previous_lhs + input.append(self.field.random()) # R_n_minus_2_of_z + + input.extend([self.field.random() for _ in range(self.max_q_degree + 1)]) + return input + + def _run_circuit_inner(self, input: list[PyFelt]): + n_pairs = self.n_pairs + circuit: multi_pairing_check.MultiPairingCheckCircuit = ( + multi_pairing_check.MultiPairingCheckCircuit( + self.name, + self.curve_id, + n_pairs=n_pairs, + hash_input=False, + compilation_mode=self.compilation_mode, + ) + ) + if self.curve_id == BN254_ID: + return circuit + + R_n_minus_1 = circuit.write_struct( + E12D("R_n_minus_1", elmts=[input.pop(0) for _ in range(12)]) + ) + c_n_minus_2 = circuit.write_struct(u384("c_n_minus_2", elmts=[input.pop(0)])) + w_of_z = circuit.write_struct(u384("w_of_z", elmts=[input.pop(0)])) + z = circuit.write_struct(u384("z", elmts=[input.pop(0)])) + c_inv_frob_1_of_z = circuit.write_struct( + u384("c_inv_frob_1_of_z", elmts=[input.pop(0)]) + ) + previous_lhs = circuit.write_struct(u384("previous_lhs", elmts=[input.pop(0)])) + R_n_minus_2_of_z = circuit.write_struct( + u384("R_n_minus_2_of_z", elmts=[input.pop(0)]) + ) + Q = circuit.write_struct( + structs.u384Array( + "Q", elmts=[input.pop(0) for _ in range(self.max_q_degree + 1)] + ) + ) + + circuit.create_powers_of_Z(z, max_degree=self.max_q_degree) + + c_n_minus_1 = circuit.mul(c_n_minus_2, c_n_minus_2) + + R_n_minus_1_of_z = circuit.eval_poly_in_precomputed_Z( + R_n_minus_1, poly_name="R_n_minus_1" + ) + + # Relation n-1 (last one) : f * w * c_inv_frob_1 + prod_k_P_of_z_n_minus_1 = circuit.mul(R_n_minus_2_of_z, c_inv_frob_1_of_z) + prod_k_P_of_z_n_minus_1 = circuit.mul(prod_k_P_of_z_n_minus_1, w_of_z) + + lhs_n_minus_1 = circuit.mul( + c_n_minus_1, + circuit.sub(prod_k_P_of_z_n_minus_1, R_n_minus_1_of_z), + comment=f"c_n_minus_1 * ((Π(n-1,k) (Pk(z)) - R_n_minus_1(z))", + ) + + final_lhs = circuit.add( + previous_lhs, lhs_n_minus_1, comment="previous_lhs + lhs_n_minus_1" + ) + P_irr, P_irr_sparsity = circuit.write_sparse_constant_elements( + get_irreducible_poly(self.curve_id, 12).get_coeffs(), + ) + P_of_z = circuit.eval_poly_in_precomputed_Z( + P_irr, P_irr_sparsity, poly_name="P_irr" + ) + + Q_of_z = circuit.eval_poly_in_precomputed_Z(Q, poly_name="big_Q") + + check = circuit.sub( + final_lhs, + circuit.mul(Q_of_z, P_of_z, comment="Q(z) * P(z)"), + comment="final_lhs - Q(z) * P(z)", + ) + + circuit.extend_struct_output(u384("final_check", elmts=[check])) + return circuit + + +class FP12MulAssertOne(BaseEXTFCircuit): + def __init__( + self, + curve_id: int, + auto_run: bool = True, + init_hash: int = None, + compilation_mode: int = 0, + ): + super().__init__( + "fp12_mul_assert_one", None, curve_id, auto_run, init_hash, compilation_mode + ) + + def build_input(self) -> list[PyFelt]: + input = [] + input.extend([self.field.random() for _ in range(12)]) # X + input.extend([self.field.random() for _ in range(12)]) # Y + input.extend([self.field.random() for _ in range(11)]) # Q + # R is known to be 1. + input.append(self.field.random()) # z + + return input + + def _run_circuit_inner(self, input: list[PyFelt]) -> ExtensionFieldModuloCircuit: + circuit = ExtensionFieldModuloCircuit( + self.name, + self.curve_id, + extension_degree=12, + init_hash=self.init_hash, + compilation_mode=self.compilation_mode, + ) + X = circuit.write_struct(E12D("X", elmts=[input.pop(0) for _ in range(12)])) + Y = circuit.write_struct(E12D("Y", elmts=[input.pop(0) for _ in range(12)])) + Q = circuit.write_struct( + structs.E12DMulQuotient("Q", elmts=[input.pop(0) for _ in range(11)]) + ) + assert len(input) == 1 + z = circuit.write_struct(u384("z", elmts=[input.pop(0)])) + assert len(input) == 0 + circuit.create_powers_of_Z(z, max_degree=12) + P_irr, P_irr_sparsity = circuit.write_sparse_constant_elements( + get_irreducible_poly(self.curve_id, 12).get_coeffs(), + ) + P_of_z = circuit.eval_poly_in_precomputed_Z( + P_irr, P_irr_sparsity, poly_name="P_irr" + ) + Q_of_z = circuit.eval_poly_in_precomputed_Z(Q, poly_name="Q") + R_of_z = circuit.set_or_get_constant(1) + + X_of_z = circuit.eval_poly_in_precomputed_Z(X, poly_name="X") + Y_of_z = circuit.eval_poly_in_precomputed_Z(Y, poly_name="Y") + check = circuit.sub( + circuit.mul(X_of_z, Y_of_z, comment="X(z) * Y(z)"), + circuit.mul(Q_of_z, P_of_z, comment="Q(z) * P(z)"), + comment="(X(z) * Y(z)) - (Q(z) * P(z))", + ) + check = circuit.sub(check, R_of_z, comment="(X(z) * Y(z) - Q(z) * P(z)) - 1") + circuit.extend_struct_output(u384("check", elmts=[check])) + + return circuit diff --git a/hydra/precompiled_circuits/multi_miller_loop.py b/hydra/precompiled_circuits/multi_miller_loop.py index 86f65f70..c939f317 100644 --- a/hydra/precompiled_circuits/multi_miller_loop.py +++ b/hydra/precompiled_circuits/multi_miller_loop.py @@ -4,12 +4,16 @@ CURVES, CurveID, precompute_lineline_sparsity, + G2Point, + G1Point, ) from hydra.extension_field_modulo_circuit import ( ExtensionFieldModuloCircuit, ModuloCircuitElement, PyFelt, ) +from hydra.hints.io import flatten +from typing import Iterator, Tuple class MultiMillerLoopCircuit(ExtensionFieldModuloCircuit): @@ -21,6 +25,8 @@ def __init__( hash_input: bool = True, init_hash: int = None, compilation_mode: int = 0, + precompute_lines: bool = False, + n_points_precomputed_lines: int = None, ): super().__init__( name=name, @@ -51,9 +57,52 @@ def __init__( "MUL_BY_LL": 0, } ) - self.output_lines_sparsities = [] + self.n_points_precomputed_lines = n_points_precomputed_lines + self.precompute_lines: bool = precompute_lines + self.precomputed_lines: list[ModuloCircuitElement] = [] + self._precomputed_lines_generator = None - def write_p_and_q(self, input: list[PyFelt], precompute_consts: bool = True): + def _create_precomputed_lines_generator( + self, + ) -> Iterator[ + Tuple[ + Tuple[ModuloCircuitElement, ModuloCircuitElement], + Tuple[ModuloCircuitElement, ModuloCircuitElement], + ] + ]: + if self.precompute_lines: + if len(self.precomputed_lines) % 4 != 0: + raise ValueError( + "Number of precomputed line elements must be a multiple of 4." + ) + for i in range(0, len(self.precomputed_lines), 4): + yield ( + (self.precomputed_lines[i], self.precomputed_lines[i + 1]), + (self.precomputed_lines[i + 2], self.precomputed_lines[i + 3]), + ) + else: + while True: + yield ((None, None), (None, None)) + + def get_next_precomputed_line( + self, + ) -> Tuple[Tuple[PyFelt, PyFelt], Tuple[PyFelt, PyFelt]]: + return next(self._precomputed_lines_generator) + + def write_p_and_q(self, P: list[G1Point], Q: list[G2Point]): + assert len(P) == len(Q) == self.n_pairs + assert set([P[i].curve_id for i in range(len(P))]) == set( + [Q[i].curve_id for i in range(len(Q))] + ) + raw = [] + for P, Q in zip(P, Q): + raw.extend([self.field(P.x), self.field(P.y)]) + raw.extend([self.field(Q.x[0]), self.field(Q.x[1])]) + raw.extend([self.field(Q.y[0]), self.field(Q.y[1])]) + self.write_p_and_q_raw(raw) + return None + + def write_p_and_q_raw(self, input: list[PyFelt], precompute_consts: bool = True): assert ( len(input) == 6 * self.n_pairs ), f"Expected {6 * self.n_pairs} inputs, got {len(input)}" @@ -81,12 +130,14 @@ def write_p_and_q(self, input: list[PyFelt], precompute_consts: bool = True): self.precompute_consts() return None - def precompute_consts(self, n_pairs: int = None): + def precompute_consts(self, n_pairs: int = None, skip_P_precompute: bool = False): n_pairs = n_pairs or self.n_pairs - self.yInv = [self.inv(self.P[i][1]) for i in range(n_pairs)] - self.xNegOverY = [ - self.neg(self.div(self.P[i][0], self.P[i][1])) for i in range(n_pairs) - ] + if not skip_P_precompute: + self.yInv = [self.inv(self.P[i][1]) for i in range(n_pairs)] + self.xNegOverY = [ + self.neg(self.div(self.P[i][0], self.P[i][1])) for i in range(n_pairs) + ] + if -1 in self.loop_counter: self.Qneg = [ (self.Q[i][0], self.extf_neg(self.Q[i][1])) for i in range(n_pairs) @@ -144,7 +195,7 @@ def compute_adding_slope( den = self.extf_sub(Qa[0], Qb[0]) return self.fp2_div(num, den) - def build_sparse_line( + def build_sparse_line_eval( self, R0: tuple[ModuloCircuitElement, ModuloCircuitElement], R1: tuple[ModuloCircuitElement, ModuloCircuitElement], @@ -196,12 +247,14 @@ def build_sparse_line( else: raise NotImplementedError - def add_step( + def _add( self, Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int, ): + if self.precompute_lines and (k + 1) <= self.n_points_precomputed_lines: + return (None, None), self.get_next_precomputed_line() λ = self.compute_adding_slope(Qa, Qb) xr = self.extf_sub(X=self.fp2_square(X=λ), Y=self.extf_add(Qa[0], Qb[0])) yr = self.extf_sub( @@ -209,30 +262,55 @@ def add_step( Y=Qa[1], ) p = (xr, yr) - line = self.build_sparse_line( - R0=λ, # Directly store λ as R0 - R1=self.extf_sub(self.fp2_mul(λ, Qa[0]), Qa[1]), + lineR0 = λ + lineR1 = self.extf_sub(self.fp2_mul(λ, Qa[0]), Qa[1]) + + return p, (lineR0, lineR1) + + def add_step( + self, + Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + k: int, + ): + p, (lineR0, lineR1) = self._add(Qa, Qb, k) + line = self.build_sparse_line_eval( + R0=lineR0, + R1=lineR1, yInv=self.yInv[k], xNegOverY=self.xNegOverY[k], ) return p, line - def line_compute( + def _line_compute( self, Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int, ): + if self.precompute_lines and (k + 1) <= self.n_points_precomputed_lines: + return self.get_next_precomputed_line() λ = self.compute_adding_slope(Qa, Qb) - line = self.build_sparse_line( - R0=λ, # Directly store λ as R0 - R1=self.extf_sub(self.fp2_mul(λ, Qa[0]), Qa[1]), + lineR0 = λ + lineR1 = self.extf_sub(self.fp2_mul(λ, Qa[0]), Qa[1]) + return lineR0, lineR1 + + def line_compute( + self, + Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + k: int, + ): + lineR0, lineR1 = self._line_compute(Qa, Qb, k) + line = self.build_sparse_line_eval( + R0=lineR0, + R1=lineR1, yInv=self.yInv[k], xNegOverY=self.xNegOverY[k], ) return line - def double_step( + def _double( self, Q: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int ): """ @@ -242,6 +320,8 @@ def double_step( :param p1: A tuple representing the point on the curve (x, y) in the extension field. :return: A tuple containing the doubled point and the line evaluation. """ + if self.precompute_lines and (k + 1) <= self.n_points_precomputed_lines: + return ((None, None), self.get_next_precomputed_line()) self.ops_counter["Double Step"] += 1 λ = self.compute_doubling_slope(Q) # Compute λ = 3x² / 2y @@ -253,22 +333,37 @@ def double_step( p = (xr, yr) - # Store the line evaluation for this doubling step - line = self.build_sparse_line( - R0=λ, # Directly store λ as R0 - R1=self.extf_sub(self.fp2_mul(λ, Q[0]), Q[1]), # Compute R1 as λ*x - y + lineR0 = λ + lineR1 = self.extf_sub(self.fp2_mul(λ, Q[0]), Q[1]) + + return p, (lineR0, lineR1) + + def double_step( + self, + Q: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + k: int, + ): + p, (lineR0, lineR1) = self._double(Q, k) + line = self.build_sparse_line_eval( + R0=lineR0, + R1=lineR1, yInv=self.yInv[k], xNegOverY=self.xNegOverY[k], ) - return p, line - def double_and_add_step( + def _double_and_add( self, Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int, ) -> list[ModuloCircuitElement]: + if self.precompute_lines and (k + 1) <= self.n_points_precomputed_lines: + return ( + (None, None), + self.get_next_precomputed_line(), + self.get_next_precomputed_line(), + ) self.ops_counter["Double-and-Add Step"] += 1 # Computes 2Qa+Qb as (Qa+Qb)+Qa # https://arxiv.org/pdf/math/0208038.pdf 3.1 @@ -279,15 +374,9 @@ def double_and_add_step( x3 = self.extf_sub(X=self.fp2_square(X=λ1), Y=self.extf_add(Qa[0], Qb[0])) # omit y3 computation + line1R0 = λ1 + line1R1 = self.extf_sub(self.fp2_mul(λ1, Qa[0]), Qa[1]) - line1 = self.build_sparse_line( - R0=λ1, - R1=self.extf_sub( - self.fp2_mul(λ1, Qa[0]), Qa[1] - ), # Compute R1 as λ1*x1 - y1 - yInv=self.yInv[k], - xNegOverY=self.xNegOverY[k], - ) # compute λ2 = -λ1-2y1/(x3-x1) num = self.extf_add(Qa[1], Qa[1]) @@ -300,20 +389,43 @@ def double_and_add_step( # compute y4 = λ2(x1 - x4)-y1 y4 = self.extf_sub(self.fp2_mul(λ2, self.extf_sub(Qa[0], x4)), Qa[1]) - line2 = self.build_sparse_line( - R0=λ2, - R1=self.extf_sub( - self.fp2_mul(λ2, Qa[0]), Qa[1] - ), # Compute R1 as λ2*x1 - y1 + line2R0 = λ2 + line2R1 = self.extf_sub(self.fp2_mul(λ2, Qa[0]), Qa[1]) + + return (x4, y4), (line1R0, line1R1), (line2R0, line2R1) + + def double_and_add_step( + self, + Qa: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + Qb: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], + k: int, + ): + (new_x, new_y), (line1R0, line1R1), (line2R0, line2R1) = self._double_and_add( + Qa, Qb, k + ) + line1 = self.build_sparse_line_eval( + R0=line1R0, + R1=line1R1, yInv=self.yInv[k], xNegOverY=self.xNegOverY[k], ) + line2 = self.build_sparse_line_eval( + R0=line2R0, + R1=line2R1, + yInv=self.yInv[k], + xNegOverY=self.xNegOverY[k], + ) + return (new_x, new_y), line1, line2 - return (x4, y4), line1, line2 - - def triple_step( + def _triple( self, Q: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int ): + if self.precompute_lines and (k + 1) <= self.n_points_precomputed_lines: + return ( + (None, None), + self.get_next_precomputed_line(), + self.get_next_precomputed_line(), + ) self.ops_counter["Triple Step"] += 1 # Compute λ = 3x² / 2y. Manually to keep den = 2y to be re-used for λ2. x0, x1 = Q[0][0], Q[0][1] @@ -328,12 +440,9 @@ def triple_step( den = self.extf_add(Q[1], Q[1]) λ1 = self.fp2_div(num, den) - line1 = self.build_sparse_line( - R0=λ1, # Directly store λ as R0 - R1=self.extf_sub(self.fp2_mul(λ1, Q[0]), Q[1]), # Compute R1 as λ*x - y - yInv=self.yInv[k], - xNegOverY=self.xNegOverY[k], - ) + line1R0 = λ1 + line1R1 = self.extf_sub(self.fp2_mul(λ1, Q[0]), Q[1]) + # x2 = λ1^2 - 2x x2 = self.extf_sub(self.fp2_square(λ1), self.extf_add(Q[0], Q[0])) # ommit yr computation, and @@ -343,12 +452,8 @@ def triple_step( # It's coded as x - x2. λ2 = self.extf_sub(self.fp2_div(den, self.extf_sub(Q[0], x2)), λ1) - line2 = self.build_sparse_line( - R0=λ2, - R1=self.extf_sub(self.fp2_mul(λ2, Q[0]), Q[1]), # Compute R1 as λ2*x1 - y1 - yInv=self.yInv[k], - xNegOverY=self.xNegOverY[k], - ) + line2R0 = λ2 + line2R1 = self.extf_sub(self.fp2_mul(λ2, Q[0]), Q[1]) # // xr = λ²-p.x-x2 @@ -357,7 +462,26 @@ def triple_step( # // yr = λ(p.x-xr) - p.y yr = self.extf_sub(self.fp2_mul(λ2, self.extf_sub(Q[0], xr)), Q[1]) - return (xr, yr), line1, line2 + return (xr, yr), (line1R0, line1R1), (line2R0, line2R1) + + def triple_step( + self, Q: tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]], k: int + ): + (new_x, new_y), (line1R0, line1R1), (line2R0, line2R1) = self._triple(Q, k) + line1 = self.build_sparse_line_eval( + R0=line1R0, # Directly store λ as R0 + R1=line1R1, # Compute R1 as λ*x - y + yInv=self.yInv[k], + xNegOverY=self.xNegOverY[k], + ) + line2 = self.build_sparse_line_eval( + R0=line2R0, + R1=line2R1, + yInv=self.yInv[k], + xNegOverY=self.xNegOverY[k], + ) + + return (new_x, new_y), line1, line2 def bit_0_case( self, @@ -448,12 +572,10 @@ def bit_1_case( ) return new_f, new_points - def bn254_finalize_step( + def _bn254_finalize_step( self, Qs: list[tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]]], ): - q1s = [] - q2s = [] nr1p2 = [ self.set_or_get_constant( self.field( @@ -512,13 +634,36 @@ def bn254_finalize_step( nr2p3, ) - T, l1 = self.add_step(Qs[k], (q1x, q1y), k) - l2 = self.line_compute(T, (q2x, q2y), k) - new_lines.append(l1) - new_lines.append(l2) + T, (l1R0, l1R1) = self._add(Qs[k], (q1x, q1y), k) + l2R0, l2R1 = self._line_compute(T, (q2x, q2y), k) + new_lines.append(((l1R0, l1R1), (l2R0, l2R1))) return new_lines + def bn254_finalize_step( + self, + Qs: list[tuple[list[ModuloCircuitElement], list[ModuloCircuitElement]]], + ): + lines = self._bn254_finalize_step(Qs) + lines_evaluated = [] + for k, (l1, l2) in enumerate(lines): + line_eval = self.build_sparse_line_eval( + R0=l1[0], + R1=l1[1], + yInv=self.yInv[k], + xNegOverY=self.xNegOverY[k], + ) + line_eval2 = self.build_sparse_line_eval( + R0=l2[0], + R1=l2[1], + yInv=self.yInv[k], + xNegOverY=self.xNegOverY[k], + ) + lines_evaluated.append(line_eval) + lines_evaluated.append(line_eval2) + + return lines_evaluated + def miller_loop(self, n_pairs: int) -> list[ModuloCircuitElement]: f = [self.set_or_get_constant(1)] + [self.set_or_get_constant(0)] * 11 @@ -574,3 +719,98 @@ def miller_loop(self, n_pairs: int) -> list[ModuloCircuitElement]: raise NotImplementedError(f"Curve {self.curve_id} not implemented") return f + + +def precompute_lines(Qs: list[G2Point]) -> list[PyFelt]: + if len(Qs) == 0: + return [] + curve_id = Qs[0].curve_id.value + loop_counter = CURVES[curve_id].loop_counter + start_index = len(loop_counter) - 2 + n_pairs = len(Qs) + circuit = MultiMillerLoopCircuit( + name="precompute_helper", + curve_id=curve_id, + n_pairs=n_pairs, + hash_input=False, + precompute_lines=False, + ) + field = circuit.field + for Q in Qs: + circuit.Q.append( + ( + [ + circuit.write_element(field(Q.x[0])), + circuit.write_element(field(Q.x[1])), + ], + [ + circuit.write_element(field(Q.y[0])), + circuit.write_element(field(Q.y[1])), + ], + ) + ) + circuit.precompute_consts(skip_P_precompute=True) + + lines = [] + if loop_counter[start_index] == 1: + # Handle case when first bit is +1, need to triple point instead of double and add. + new_points = [] + new_lines = [] + for k in range(n_pairs): + T, l1, l2 = circuit._triple(circuit.Q[k], k) + new_lines.append(l1) + new_lines.append(l2) + new_points.append(T) + + elif loop_counter[start_index] == 0: + new_lines = [] + new_points = [] + for k in range(n_pairs): + T, l1 = circuit._double(circuit.Q[k], k) + new_lines.append(l1) + new_points.append(T) + else: + raise NotImplementedError( + f"Init bit {loop_counter[start_index]} not implemented" + ) + + points = new_points + lines.append(new_lines) + # Rest of miller loop.= + for i in range(start_index - 1, -1, -1): + + new_lines = [] + new_points = [] + if loop_counter[i] == 0: + for k in range(n_pairs): + T, l1 = circuit._double(points[k], k) + new_lines.append(l1) + new_points.append(T) + elif loop_counter[i] == 1 or loop_counter[i] == -1: + # Choose Q or -Q depending on the bit for the addition. + Q_selects = [ + circuit.Q[k] if loop_counter[i] == 1 else circuit.Qneg[k] + for k in range(n_pairs) + ] + for k in range(n_pairs): + T, l1, l2 = circuit._double_and_add(points[k], Q_selects[k], k) + new_lines.append(l1) + new_lines.append(l2) + new_points.append(T) + else: + raise NotImplementedError(f"Bit {loop_counter[i]} not implemented") + points = new_points + lines.append(new_lines) + + if curve_id == CurveID.BN254.value: + final_lines = circuit._bn254_finalize_step(points) + for l1, l2 in final_lines: + lines.append(l1) + lines.append(l2) + + elif curve_id == CurveID.BLS12_381.value: + pass + else: + raise NotImplementedError(f"Curve {curve_id} not implemented") + + return [x.felt for x in flatten(lines)] diff --git a/hydra/precompiled_circuits/multi_pairing_check.py b/hydra/precompiled_circuits/multi_pairing_check.py index 4be2e849..7f510d3e 100644 --- a/hydra/precompiled_circuits/multi_pairing_check.py +++ b/hydra/precompiled_circuits/multi_pairing_check.py @@ -34,13 +34,7 @@ def get_root_and_scaling_factor( c_input: list[PyFelt] = [] if isinstance(P[0], G1Point): - for p, q in zip(P, Q): - c_input.append(field(p.x)) - c_input.append(field(p.y)) - c_input.append(field(q.x[0])) - c_input.append(field(q.x[1])) - c_input.append(field(q.y[0])) - c_input.append(field(q.y[1])) + c.write_p_and_q(P, Q) elif isinstance(P[0], tuple) and isinstance(P[0][0], ModuloCircuitElement): for p, q in zip(P, Q): c_input.append(p[0].felt) @@ -49,11 +43,12 @@ def get_root_and_scaling_factor( c_input.append(q[0][1].felt) c_input.append(q[1][0].felt) c_input.append(q[1][1].felt) + c.write_p_and_q_raw(c_input) c: MultiMillerLoopCircuit = MultiMillerLoopCircuit( name="mock", curve_id=curve_id, n_pairs=len(P) ) - c.write_p_and_q(c_input) + f = E12.from_direct(c.miller_loop(len(P)), curve_id) if m is not None: M = E12.from_direct(m, curve_id) @@ -437,7 +432,7 @@ def get_pairing_check_input( mloop_circuit = MultiMillerLoopCircuit( name="mock", curve_id=curve_id.value, n_pairs=1 ) - mloop_circuit.write_p_and_q(c_input[-6:]) + mloop_circuit.write_p_and_q_raw(c_input[-6:]) M = mloop_circuit.miller_loop(n_pairs=1) M = [mi.felt for mi in M] return c_input[:-6], M @@ -454,7 +449,7 @@ def test_mpcheck(curve_id: CurveID, n_pairs: int, include_m: bool = False): circuit_input, m = get_pairing_check_input( curve_id, n_pairs, include_m=include_m ) - c.write_p_and_q(circuit_input) + c.write_p_and_q_raw(circuit_input) M = c.write_elements(m, WriteOps.INPUT) if m is not None else None c.multi_pairing_check(n_pairs, M) c.finalize_circuit() diff --git a/tests/benchmarks.py b/tests/benchmarks.py index c675254d..21f4de81 100644 --- a/tests/benchmarks.py +++ b/tests/benchmarks.py @@ -152,7 +152,7 @@ def test_miller_n(curve_id, n): pairs.extend(pair) c = MultiMillerLoopCircuit(f"Miller n={n} {curve_id.name}", curve_id.value, n) - c.write_p_and_q([field(x) for x in pairs]) + c.write_p_and_q_raw([field(x) for x in pairs]) f = c.miller_loop(n_pairs=n) diff --git a/tests/hydra/circuits/test_multi_miller_loop.py b/tests/hydra/circuits/test_multi_miller_loop.py new file mode 100644 index 00000000..b7442fdf --- /dev/null +++ b/tests/hydra/circuits/test_multi_miller_loop.py @@ -0,0 +1,115 @@ +from hydra.precompiled_circuits.multi_miller_loop import ( + MultiMillerLoopCircuit, + precompute_lines, +) +from hydra.modulo_circuit import WriteOps + +from hydra.definitions import CurveID, G1Point, G2Point +from hydra.hints.extf_mul import nondeterministic_extension_field_mul_divmod + +import pytest +import copy + + +@pytest.fixture( + params=[ + (CurveID.BN254, 1), + (CurveID.BLS12_381, 1), + (CurveID.BN254, 2), + (CurveID.BLS12_381, 2), + (CurveID.BN254, 3), + (CurveID.BLS12_381, 3), + (CurveID.BN254, 4), + (CurveID.BLS12_381, 4), + ] +) +def circuit_and_points( + request, +) -> tuple[MultiMillerLoopCircuit, list[G1Point], list[G2Point]]: + curve_id, n_pairs = request.param + + def init_miller_loop_circuit() -> ( + tuple[MultiMillerLoopCircuit, list[G1Point], list[G2Point]] + ): + circuit = MultiMillerLoopCircuit( + name="test", curve_id=curve_id.value, n_pairs=n_pairs, hash_input=False + ) + Ps = [G1Point.gen_random_point(curve_id) for _ in range(n_pairs)] + Qs = [G2Point.gen_random_point(curve_id) for _ in range(n_pairs)] + circuit.write_p_and_q(Ps, Qs) + return circuit, Ps, Qs + + return init_miller_loop_circuit() + + +def test_precomputed_and_without_precompute_gives_same_output( + circuit_and_points: tuple[MultiMillerLoopCircuit, list[G1Point], list[G2Point]] +): + circuit0, Ps, Qs = circuit_and_points + n_pairs = len(Ps) + + f0 = circuit0.miller_loop(n_pairs) + f0 = [fi.felt for fi in f0] + + circuit1 = copy.deepcopy(circuit0) + circuit1.precompute_lines = True + circuit1.precomputed_lines = circuit1.write_elements( + precompute_lines(Qs), WriteOps.INPUT + ) + circuit1._precomputed_lines_generator = ( + circuit1._create_precomputed_lines_generator() + ) + circuit1.n_points_precomputed_lines = n_pairs + + f1 = circuit1.miller_loop(n_pairs) + f1 = [fi.felt for fi in f1] + + assert f0 == f1 + + +def test_partially_precomputed_and_without_precompute_gives_same_output( + circuit_and_points: tuple[MultiMillerLoopCircuit, list[G1Point], list[G2Point]] +): + circuit0, Ps, Qs = circuit_and_points + n_pairs = len(Ps) + f0 = circuit0.miller_loop(n_pairs) + f0 = [fi.felt for fi in f0] + + circuit1 = copy.deepcopy(circuit0) + circuit1.precompute_lines = True + circuit1.precomputed_lines = circuit1.write_elements( + precompute_lines(Qs[: n_pairs // 2]), WriteOps.INPUT + ) + circuit1._precomputed_lines_generator = ( + circuit1._create_precomputed_lines_generator() + ) + circuit1.n_points_precomputed_lines = n_pairs // 2 + + f1 = circuit1.miller_loop(n_pairs) + f1 = [fi.felt for fi in f1] + + assert f0 == f1 + + +def test_prod_miller_loop_equals_multi_miller_loop( + circuit_and_points: tuple[MultiMillerLoopCircuit, list[G1Point], list[G2Point]] +): + circuit_multi, Ps, Qs = circuit_and_points + curve_id = circuit_multi.curve_id + n_pairs = len(Ps) + f0 = circuit_multi.miller_loop(n_pairs) + f0 = [fi.felt for fi in f0] + + fis = [] + for i in range(n_pairs): + circuit_prod_i = MultiMillerLoopCircuit( + name="test", curve_id=curve_id, n_pairs=1, hash_input=False + ) + circuit_prod_i.write_p_and_q(Ps[i : i + 1], Qs[i : i + 1]) + f_i = circuit_prod_i.miller_loop(1) + + fis.append(f_i) + + _, R = nondeterministic_extension_field_mul_divmod(fis, curve_id, 12) + + assert f0 == R diff --git a/tools/starknet_cli.py b/tools/starknet_cli.py index 024b7f80..e4fa90be 100644 --- a/tools/starknet_cli.py +++ b/tools/starknet_cli.py @@ -162,7 +162,7 @@ def multi_pairing_check_calldata( circuit = MultiMillerLoopCircuit( name="precompute M", curve_id=curve_id.value, n_pairs=1 ) - circuit.write_p_and_q(extra_pair.to_pyfelt_list()) + circuit.write_p_and_q_raw(extra_pair.to_pyfelt_list()) M = circuit.miller_loop(n_pairs=1) M = [mi.felt for mi in M] else: @@ -177,7 +177,7 @@ def multi_pairing_check_calldata( p_q_input = [] for pair in pairs: p_q_input.extend(pair.to_pyfelt_list()) - mpcheck_circuit.write_p_and_q(p_q_input) + mpcheck_circuit.write_p_and_q_raw(p_q_input) _, lambda_root, lambda_root_inverse, scaling_factor, scaling_factor_sparsity = ( mpcheck_circuit.multi_pairing_check(len(pairs), M) )