diff --git a/src/definitions.cairo b/src/definitions.cairo index ea91f4c2..058cc1ef 100644 --- a/src/definitions.cairo +++ b/src/definitions.cairo @@ -1,18 +1,18 @@ from starkware.cairo.common.cairo_builtins import UInt384 -from src.precompiled_circuits.final_exp import BN254_final_exp, BLS12_381_final_exp namespace bls { const CURVE_ID = 'bls12_381'; + // p = 0x1A0111EA397FE69A4B1BA7B6434BACD764774B84F38512BF6730D2A0F6B0F6241EABFFFEB153FFFFB9FEFFFFFFFFAAAB const P0 = 0xb153ffffb9feffffffffaaab; const P1 = 0x6730d2a0f6b0f6241eabfffe; const P2 = 0x434bacd764774b84f38512bf; const P3 = 0x1a0111ea397fe69a4b1ba7b6; // The following constants represent the size of the curve: - // const n = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 - const N0 = 0x3e5bfeffffffff00000001; - const N1 = 0x2020268760154ef6900bff; - const N2 = 0x73eda753299d7d483339d; + // n = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 + const N0 = 0xfffe5bfeffffffff00000001; + const N1 = 0x3339d80809a1d80553bda402; + const N2 = 0x73eda753299d7d48; // Non residue constants: const NON_RESIDUE_E2_a0 = 1; @@ -21,16 +21,17 @@ namespace bls { namespace bn { const CURVE_ID = 'bn254'; - const P0 = 60193888514187762220203335; - const P1 = 27625954992973055882053025; - const P2 = 3656382694611191768777988; + // p = 0x30644E72E131A029B85045B68181585D97816A916871CA8D3C208C16D87CFD47 + const P0 = 0x6871ca8d3c208c16d87cfd47; + const P1 = 0xb85045b68181585d97816a91; + const P2 = 0x30644e72e131a029; // The following constants represent the size of the curve: // n = n(u) = 36u^4 + 36u^3 + 18u^2 + 6u + 1 - // const n = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 - const N0 = 0x39709143e1f593f0000001; - const N1 = 0x16da06056174a0cfa121e6; - const N2 = 0x30644e72e131a029b8504; + // n = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 + const N0 = 0x79b9709143e1f593f0000001; + const N1 = 0xb85045b68181585d2833e848; + const N2 = 0x30644e72e131a029; // Non residue constants: const NON_RESIDUE_E2_a0 = 9; @@ -49,70 +50,52 @@ func get_P(curve_id: felt) -> (prime: UInt384) { } } -// func get_final_exp_circuit(curve_id: felt) -> ( -// constants_ptr: felt*, -// add_offsets_ptr: felt*, -// mul_offsets_ptr: felt*, -// left_assert_eq_offsets_ptr: felt*, -// right_assert_eq_offsets_ptr: felt*, -// poseidon_indexes_ptr: felt*, -// constants_ptr_len: felt, -// add_mod_n: felt, -// mul_mod_n: felt, -// commitments_len: felt, -// assert_eq_len: felt, -// N_Euclidean_equations: felt, -// ) { -// if (curve_id == bls.CURVE_ID) { -// return ( -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// 0, -// 0, -// 0, -// 0, -// 0, -// 0, -// ); -// } else { -// if (curve_id == bn.CURVE_ID) { -// return get_GaragaBN254FinalExp_non_interactive_circuit(); -// } else { -// return ( -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// cast(0, felt*), -// 0, -// 0, -// 0, -// 0, -// 0, -// 0, -// ); -// } -// } -// } +const SUPPORTED_CURVE_ID = 0; +const UNSUPPORTED_CURVE_ID = 1; + +func is_curve_id_supported(curve_id: felt) -> (res: felt) { + if (curve_id == bls.CURVE_ID) { + return (res=SUPPORTED_CURVE_ID); + } else { + if (curve_id == bn.CURVE_ID) { + return (res=SUPPORTED_CURVE_ID); + } else { + return (res=UNSUPPORTED_CURVE_ID); + } + } +} // Base for UInt384 / BigInt4 const BASE = 2 ** 96; +const BASE_DIV_2 = 2 ** 95; const N_LIMBS = 4; -const STARK_MIN_ONE_D2 = 576460752303423505; +const STARK_MIN_ONE_D2 = 0x800000000000011; + +struct G1Point { + x: UInt384, + y: UInt384, +} + +struct G2Point { + x0: UInt384, + x1: UInt384, + y0: UInt384, + y1: UInt384, +} + +struct G1G2Pair { + P: G1Point, + Q: G2Point, +} struct E6D { - w0: UInt384, - w1: UInt384, - w2: UInt384, - w3: UInt384, - w4: UInt384, - w5: UInt384, + v0: UInt384, + v1: UInt384, + v2: UInt384, + v3: UInt384, + v4: UInt384, + v5: UInt384, } struct E12D { @@ -130,48 +113,239 @@ struct E12D { w11: UInt384, } -struct ExtFCircuitInfo { - constants_ptr: felt*, - constants_ptr_len: felt, - add_offsets: felt*, - mul_offsets: felt*, - commitments_len: felt, - transcript_indexes: felt*, - N_Euclidean_equations: felt, -} - -func zero_E12D() -> E12D { - let res = E12D( - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), +func zero_E12D() -> (res: E12D) { + return ( + res=E12D( + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + ), + ); +} + +func one_E12D() -> (res: E12D) { + return ( + res=E12D( + UInt384(1, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + ), + ); +} + +func one_E6D() -> (res: E6D) { + return ( + res=E6D( + UInt384(1, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + UInt384(0, 0, 0, 0), + ), + ); +} + +struct UnreducedBigInt7 { + d0: felt, + d1: felt, + d2: felt, + d3: felt, + d4: felt, + d5: felt, + d6: felt, +} + +func UInt384_mul(x: UInt384, y: UInt384) -> (res: UnreducedBigInt7) { + return ( + UnreducedBigInt7( + d0=x.d0 * y.d0, + d1=x.d0 * y.d1 + x.d1 * y.d0, + d2=x.d0 * y.d2 + x.d1 * y.d1 + x.d2 * y.d0, + d3=x.d0 * y.d3 + x.d1 * y.d2 + x.d2 * y.d1 + x.d3 * y.d0, + d4=x.d1 * y.d3 + x.d2 * y.d2 + x.d3 * y.d1, + d5=x.d2 * y.d3 + x.d3 * y.d2, + d6=x.d3 * y.d3, + ), + ); +} + +func is_zero_mod_P{range_check_ptr, range_check96_ptr: felt*}(x: UInt384, p: UInt384) -> ( + res: felt +) { + alloc_locals; + %{ + from src.hints.io import bigint_pack + x = bigint_pack(ids.x, 4, 2**96) + p = bigint_pack(ids.p, 4, 2**96) + x=x%p + %} + if (nondet %{ x == 0 %} != 0) { + verify_zero4(x, p); + return (res=1); + } + + local x_inv: UInt384; + %{ + from src.hints.io import bigint_fill + bigint_fill(pow(x, -1, p), ids.x_inv, ids.N_LIMBS, ids.BASE) + %} + assert [range_check96_ptr] = x_inv.d0; + assert [range_check96_ptr + 1] = x_inv.d1; + assert [range_check96_ptr + 2] = x_inv.d2; + assert [range_check96_ptr + 3] = x_inv.d3; + + tempvar range_check96_ptr = range_check96_ptr + 4; + + let (x_x_inv) = UInt384_mul(x, x_inv); + + // Check that x * x_inv = 1 to verify that x != 0. + verify_zero7( + UnreducedBigInt7( + d0=x_x_inv.d0 - 1, + d1=x_x_inv.d1, + d2=x_x_inv.d2, + d3=x_x_inv.d3, + d4=x_x_inv.d4, + d5=x_x_inv.d5, + d6=x_x_inv.d6, + ), + p, ); - return res; -} - -func one_E12D() -> E12D { - let res = E12D( - UInt384(0, 0, 0, 0), - UInt384(1, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), - UInt384(0, 0, 0, 0), + return (res=0); +} + +func verify_zero4{range_check_ptr, range_check96_ptr: felt*}(x: UInt384, p: UInt384) { + alloc_locals; + local q: felt; + %{ + from src.hints.io import bigint_pack + + x = bigint_pack(ids.x, 4, 2**96) + p = bigint_pack(ids.p, 4, 2**96) + + q, r = divmod(x, p) + assert r == 0, f"verify_zero: Invalid input." + ids.q=q + %} + + assert [range_check96_ptr] = q; + + tempvar carry1 = (q * p.d0 - x.d0) / BASE; + assert [range_check_ptr] = carry1 + 2 ** 127; + + tempvar carry2 = (q * p.d1 - x.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (q * p.d2 - x.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + assert q * p.d3 - x.d3 + carry3 = 0; + + let range_check_ptr = range_check_ptr + 3; + let range_check96_ptr = range_check96_ptr + 1; + + return (); +} + +func verify_zero7{range_check_ptr, range_check96_ptr: felt*}(val: UnreducedBigInt7, p: UInt384) { + alloc_locals; + local q: UInt384; + %{ + from src.hints.io import bigint_pack, bigint_fill + + val = bigint_pack(ids.val, 7, 2**96) + p = bigint_pack(ids.p, 4, 2**96) + + q, r = divmod(val, p) + assert r == 0, f"verify_zero: Invalid input {val%p}." + bigint_fill(q, ids.q, ids.N_LIMBS, ids.BASE) + %} + + assert [range_check96_ptr] = q.d0; + assert [range_check96_ptr + 1] = q.d1; + assert [range_check96_ptr + 2] = q.d2; + assert [range_check96_ptr + 3] = q.d3; + + tempvar q_P: UnreducedBigInt7 = UnreducedBigInt7( + d0=q.d0 * p.d0, + d1=q.d0 * p.d1 + q.d1 * p.d0, + d2=q.d0 * p.d2 + q.d1 * p.d1 + q.d2 * p.d0, + d3=q.d0 * p.d3 + q.d1 * p.d2 + q.d2 * p.d1 + q.d3 * p.d0, + d4=q.d1 * p.d3 + q.d2 * p.d2 + q.d3 * p.d1, + d5=q.d2 * p.d3 + q.d3 * p.d2, + d6=q.d3 * p.d3, ); - return res; + + tempvar carry1 = (q_P.d0 - val.d0) / BASE; + assert [range_check_ptr + 0] = carry1 + 2 ** 127; + + tempvar carry2 = (q_P.d1 - val.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (q_P.d2 - val.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + tempvar carry4 = (q_P.d3 - val.d3 + carry3) / BASE; + assert [range_check_ptr + 3] = carry4 + 2 ** 127; + + tempvar carry5 = (q_P.d4 - val.d4 + carry4) / BASE; + assert [range_check_ptr + 4] = carry5 + 2 ** 127; + + tempvar carry6 = (q_P.d5 - val.d5 + carry5) / BASE; + assert [range_check_ptr + 5] = carry6 + 2 ** 127; + + assert q_P.d6 - val.d6 + carry6 = 0; + + tempvar range_check_ptr = range_check_ptr + 6; + tempvar range_check96_ptr = range_check96_ptr + 4; + return (); +} + +func is_zero_E6D{range_check_ptr, range_check96_ptr: felt*}(x: E6D, curve_id: felt) -> (res: felt) { + alloc_locals; + let (P) = get_P(curve_id); + let (is_zero_v0) = is_zero_mod_P(x.v0, P); + let (is_zero_v1) = is_zero_mod_P(x.v1, P); + let (is_zero_v2) = is_zero_mod_P(x.v2, P); + let (is_zero_v3) = is_zero_mod_P(x.v3, P); + let (is_zero_v4) = is_zero_mod_P(x.v4, P); + let (is_zero_v5) = is_zero_mod_P(x.v5, P); + if (is_zero_v0 == 0) { + return (res=0); + } + if (is_zero_v1 == 0) { + return (res=0); + } + if (is_zero_v2 == 0) { + return (res=0); + } + if (is_zero_v3 == 0) { + return (res=0); + } + if (is_zero_v4 == 0) { + return (res=0); + } + if (is_zero_v5 == 0) { + return (res=0); + } + return (res=1); } diff --git a/src/ec_ops.cairo b/src/ec_ops.cairo index e69de29b..e51c416a 100644 --- a/src/ec_ops.cairo +++ b/src/ec_ops.cairo @@ -0,0 +1,28 @@ +from src.definitions import is_zero_mod_P, get_P +from src.precompiled_circuits.ec import get_IS_ON_CURVE_G1_G2_circuit +from src.modulo_circuit import run_modulo_circuit, ModuloCircuit +from starkware.cairo.common.cairo_builtins import ModBuiltin, UInt384 + +func is_on_curve_g1_g2{ + range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* +}(curve_id: felt, input: felt*) -> (res: felt) { + alloc_locals; + let (P) = get_P(curve_id); + + let (circuit) = get_IS_ON_CURVE_G1_G2_circuit(curve_id); + let (output: felt*) = run_modulo_circuit(circuit, input); + let (check_g1: felt) = is_zero_mod_P([cast(output, UInt384*)], P); + let (check_g20: felt) = is_zero_mod_P([cast(output + UInt384.SIZE, UInt384*)], P); + let (check_g21: felt) = is_zero_mod_P([cast(output + 2 * UInt384.SIZE, UInt384*)], P); + + if (check_g1 == 0) { + return (res=0); + } + if (check_g20 == 0) { + return (res=0); + } + if (check_g21 == 0) { + return (res=0); + } + return (res=1); +} diff --git a/src/modulo_circuit.cairo b/src/modulo_circuit.cairo index 65533414..7fa6e55e 100644 --- a/src/modulo_circuit.cairo +++ b/src/modulo_circuit.cairo @@ -3,86 +3,282 @@ from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.modulo import run_mod_p_circuit -from src.definitions import get_P, BASE, N_LIMBS -from src.precompiled_circuits.sample import get_sample_circuit -from src.utils import ( - get_Z_and_RLC_from_transcript, - write_felts_to_value_segment, - assert_limbs_at_index_are_equal, -) +from src.definitions import get_P, BASE, N_LIMBS, is_curve_id_supported, SUPPORTED_CURVE_ID +from src.utils import get_Z_and_RLC_from_transcript, write_felts_to_value_segment, retrieve_output + +struct ExtensionFieldModuloCircuit { + constants_ptr: felt*, + add_offsets_ptr: felt*, + mul_offsets_ptr: felt*, + output_offsets_ptr: felt*, + poseidon_indexes_ptr: felt*, + constants_ptr_len: felt, + input_len: felt, + commitments_len: felt, + witnesses_len: felt, + output_len: felt, + continuous_output: felt, + add_mod_n: felt, + mul_mod_n: felt, + n_assert_eq: felt, + N_Euclidean_equations: felt, + name: felt, + curve_id: felt, +} + +struct ModuloCircuit { + constants_ptr: felt*, + add_offsets_ptr: felt*, + mul_offsets_ptr: felt*, + output_offsets_ptr: felt*, + constants_ptr_len: felt, + witnesses_len: felt, + input_len: felt, + output_len: felt, + continuous_output: felt, + add_mod_n: felt, + mul_mod_n: felt, + n_assert_eq: felt, + name: felt, + curve_id: felt, +} + +func get_void_modulo_circuit() -> (circuit: ModuloCircuit*) { + return (cast(-1, ModuloCircuit*),); +} +func get_void_extension_field_modulo_circuit() -> (circuit: ExtensionFieldModuloCircuit*) { + return (cast(-1, ExtensionFieldModuloCircuit*),); +} + +func run_modulo_circuit{ + range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* +}(circuit: ModuloCircuit*, input: felt*) -> (output: felt*) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let p: UInt384 = get_P(circuit.curve_id); + local values_ptr: UInt384* = cast(range_check96_ptr, UInt384*); + memcpy( + dst=range_check96_ptr, src=circuit.constants_ptr, len=circuit.constants_ptr_len * N_LIMBS + ); // write(Constants) + memcpy( + dst=range_check96_ptr + circuit.constants_ptr_len * N_LIMBS, + src=input, + len=circuit.input_len, + ); // write(Input) + + %{ + from src.precompiled_circuits.all_circuits import ALL_EXTF_CIRCUITS, CircuitID + from src.hints.io import pack_bigint_ptr, fill_felt_ptr, flatten + from src.definitions import CURVES, PyFelt + p = CURVES[ids.circuit.curve_id].p + circuit_input = pack_bigint_ptr(memory, ids.input, ids.N_LIMBS, ids.BASE, ids.circuit.input_len//ids.N_LIMBS) + MOD_CIRCUIT = ALL_EXTF_CIRCUITS[CircuitID(ids.circuit.name)](ids.circuit.curve_id, auto_run=False) + MOD_CIRCUIT = MOD_CIRCUIT.run_circuit(circuit_input) + + witnesses = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in MOD_CIRCUIT.witnesses]) + + fill_felt_ptr(x=witnesses, memory=memory, address=ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len) + #MOD_CIRCUIT.print_value_segment() + %} + + run_mod_p_circuit( + p=p, + values_ptr=values_ptr, + add_mod_offsets_ptr=circuit.add_offsets_ptr, + add_mod_n=circuit.add_mod_n, + mul_mod_offsets_ptr=circuit.mul_offsets_ptr, + mul_mod_n=circuit.mul_mod_n, + ); + + tempvar range_check96_ptr = range_check96_ptr + circuit.input_len + circuit.witnesses_len + ( + circuit.constants_ptr_len + circuit.add_mod_n + circuit.mul_mod_n - circuit.n_assert_eq + ) * N_LIMBS; + + let (output: felt*) = retrieve_output( + values_segment=values_ptr, + output_offsets_ptr=circuit.output_offsets_ptr, + n=circuit.output_len, + continuous_output=circuit.continuous_output, + ); + return (output=output); +} + func run_extension_field_modulo_circuit{ range_check_ptr, poseidon_ptr: PoseidonBuiltin*, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*, -}(input: felt*, input_len: felt, curve_id: felt, circuit_id: felt) -> felt { +}(circuit: ExtensionFieldModuloCircuit*, input: felt*) -> (output: felt*, Z: felt) { alloc_locals; let (__fp__, _) = get_fp_and_pc(); - let p: UInt384 = get_P(curve_id); - let ( - constants_ptr: felt*, - add_offsets_ptr: felt*, - mul_offsets_ptr: felt*, - left_assert_eq_offsets_ptr: felt*, - right_assert_eq_offsets_ptr: felt*, - poseidon_indexes_ptr: felt*, - constants_ptr_len: felt, - add_mod_n: felt, - mul_mod_n: felt, - commitments_len: felt, - assert_eq_len: felt, - N_Euclidean_equations: felt, - ) = get_sample_circuit(circuit_id); + let (status) = is_curve_id_supported(circuit.curve_id); + if (status != SUPPORTED_CURVE_ID) { + return (cast(-1, felt*), 0); + } + let p: UInt384 = get_P(circuit.curve_id); local values_ptr: UInt384* = cast(range_check96_ptr, UInt384*); - memcpy(dst=range_check96_ptr, src=constants_ptr, len=constants_ptr_len * N_LIMBS); // write(Constants) - memcpy(dst=range_check96_ptr + constants_ptr_len * N_LIMBS, src=input, len=input_len); // write(Input) + memcpy( + dst=range_check96_ptr, src=circuit.constants_ptr, len=circuit.constants_ptr_len * N_LIMBS + ); // write(Constants) + + memcpy( + dst=range_check96_ptr + circuit.constants_ptr_len * N_LIMBS, + src=input, + len=circuit.input_len, + ); // write(Input) - local commitments: felt*; %{ - from src.precompiled_circuits.sample import get_sample_circuit - from src.hints.io import pack_bigint_ptr, flatten - from src.definitions import CURVES, PyFelt - p = CURVES[ids.curve_id].p - circuit_input = pack_bigint_ptr(memory, ids.input, ids.N_LIMBS, ids.BASE, ids.input_len//ids.N_LIMBS) - circuit_input = [PyFelt(x, p) for x in circuit_input] - EXTF_MOD_CIRCUIT = get_sample_circuit(ids.circuit_id, circuit_input) + from src.precompiled_circuits.all_circuits import ALL_EXTF_CIRCUITS, CircuitID + from src.hints.io import bigint_split, pack_bigint_ptr, fill_felt_ptr, flatten + circuit_input = pack_bigint_ptr(memory, ids.input, ids.N_LIMBS, ids.BASE, ids.circuit.input_len//ids.N_LIMBS) + EXTF_MOD_CIRCUIT = ALL_EXTF_CIRCUITS[CircuitID(ids.circuit.name)](ids.circuit.curve_id, auto_run=False) + + EXTF_MOD_CIRCUIT = EXTF_MOD_CIRCUIT.run_circuit(circuit_input) + print(f"\t{ids.circuit.constants_ptr_len} Constants and {ids.circuit.input_len//4} Inputs copied to RC_96 memory segment at position {ids.range_check96_ptr}") + commitments = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in EXTF_MOD_CIRCUIT.commitments]) - ids.commitments = segments.gen_arg(commitments) - print(len(commitments), len(commitments)//4) + witnesses = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in EXTF_MOD_CIRCUIT.witnesses]) + fill_felt_ptr(x=commitments, memory=memory, address=ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len) + fill_felt_ptr(x=witnesses, memory=memory, address=ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len + ids.circuit.commitments_len) + print(f"\t{len(commitments)//4} Commitments & {len(witnesses)//4} witnesses computed and filled in RC_96 memory segment at positions {ids.range_check96_ptr+ids.circuit.constants_ptr_len * ids.N_LIMBS+ids.circuit.input_len} and {ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len + ids.circuit.commitments_len}") + + #EXTF_MOD_CIRCUIT.print_value_segment() %} + let (local Z: felt, local RLC_coeffs: felt*) = get_Z_and_RLC_from_transcript( + transcript_start=cast(values_ptr, felt*) + circuit.constants_ptr_len * N_LIMBS, + poseidon_indexes_ptr=circuit.poseidon_indexes_ptr, + n_elements_in_transcript=(circuit.commitments_len + circuit.input_len) / N_LIMBS, + n_equations=circuit.N_Euclidean_equations, + init_hash=circuit.name, + ); + %{ print(f"\tZ = Hash(Input|Commitments) = Poseidon({(ids.circuit.input_len+ids.circuit.commitments_len)//ids.N_LIMBS} * [Uint384]) computed") %} + %{ print(f"\tN={ids.circuit.N_Euclidean_equations} felt252 from Poseidon transcript retrieved.") %} + + %{ + # Sanity Check : + assert ids.Z == EXTF_MOD_CIRCUIT.transcript.continuable_hash, f"Z for circuit {EXTF_MOD_CIRCUIT.name} does not match {hex(ids.Z)} {hex(EXTF_MOD_CIRCUIT.transcript.continuable_hash)}" + %} + + tempvar range_check96_ptr = range_check96_ptr + circuit.constants_ptr_len * N_LIMBS + + circuit.input_len + circuit.commitments_len + circuit.witnesses_len; + + write_felts_to_value_segment(values_start=&Z, n=1); + write_felts_to_value_segment(values_start=RLC_coeffs, n=circuit.N_Euclidean_equations); + %{ print(f"\tZ and felt252 written to value segment") %} + %{ print(f"\tRunning ModuloBuiltin circuit...") %} + run_mod_p_circuit( + p=p, + values_ptr=values_ptr, + add_mod_offsets_ptr=circuit.add_offsets_ptr, + add_mod_n=circuit.add_mod_n, + mul_mod_offsets_ptr=circuit.mul_offsets_ptr, + mul_mod_n=circuit.mul_mod_n, + ); + + tempvar range_check96_ptr = range_check96_ptr + ( + circuit.add_mod_n + circuit.mul_mod_n - circuit.n_assert_eq + ) * N_LIMBS; + + let (output: felt*) = retrieve_output( + values_segment=values_ptr, + output_offsets_ptr=circuit.output_offsets_ptr, + n=circuit.output_len, + continuous_output=circuit.continuous_output, + ); + return (output=output, Z=Z); +} + +// Same as run_modulo_circuit, but doesen't hash the inputs and starts with +// an initial hash. +func run_extension_field_modulo_circuit_continuation{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}(circuit: ExtensionFieldModuloCircuit*, input: felt*, init_hash: felt) -> ( + output: felt*, Z: felt +) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let (status) = is_curve_id_supported(circuit.curve_id); + if (status != SUPPORTED_CURVE_ID) { + return (cast(-1, felt*), 0); + } + let p: UInt384 = get_P(circuit.curve_id); + + local values_ptr: UInt384* = cast(range_check96_ptr, UInt384*); memcpy( - dst=range_check96_ptr + constants_ptr_len * N_LIMBS + input_len, - src=commitments, - len=commitments_len * N_LIMBS, - ); // write(Commitments) + dst=range_check96_ptr, src=circuit.constants_ptr, len=circuit.constants_ptr_len * N_LIMBS + ); // write(Constants) + + memcpy( + dst=range_check96_ptr + circuit.constants_ptr_len * N_LIMBS, + src=input, + len=circuit.input_len, + ); // write(Input) + + %{ + from src.precompiled_circuits.all_circuits import ALL_EXTF_CIRCUITS, CircuitID + from src.hints.io import bigint_split, pack_bigint_ptr, fill_felt_ptr, flatten + circuit_input = pack_bigint_ptr(memory, ids.input, ids.N_LIMBS, ids.BASE, ids.circuit.input_len//ids.N_LIMBS) + EXTF_MOD_CIRCUIT = ALL_EXTF_CIRCUITS[CircuitID(ids.circuit.name)](ids.circuit.curve_id, auto_run=False, init_hash=ids.init_hash) + + EXTF_MOD_CIRCUIT = EXTF_MOD_CIRCUIT.run_circuit(input=circuit_input) + print(f"\t{ids.circuit.constants_ptr_len} Constants and {ids.circuit.input_len//4} Inputs copied to RC_96 memory segment at position {ids.range_check96_ptr}") + + commitments = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in EXTF_MOD_CIRCUIT.commitments]) + witnesses = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in EXTF_MOD_CIRCUIT.witnesses]) + fill_felt_ptr(x=commitments, memory=memory, address=ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len) + fill_felt_ptr(x=witnesses, memory=memory, address=ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len + ids.circuit.commitments_len) + # print(f"continuation segment:, init_hash={hex(ids.init_hash)}") + #EXTF_MOD_CIRCUIT.print_value_segment() + print(f"\t{len(commitments)//4} Commitments & {len(witnesses)//4} witnesses computed and filled in RC_96 memory segment at positions {ids.range_check96_ptr+ids.circuit.constants_ptr_len * ids.N_LIMBS+ids.circuit.input_len} and {ids.range_check96_ptr + ids.circuit.constants_ptr_len * ids.N_LIMBS + ids.circuit.input_len + ids.circuit.commitments_len}") + %} + + %{ print(f"\tZ = Hash(Init_Hash|Commitments) = Poseidon(Init_Hash, Poseidon({(ids.circuit.commitments_len)//ids.N_LIMBS} * [Uint384])) computed") %} let (local Z: felt, local RLC_coeffs: felt*) = get_Z_and_RLC_from_transcript( - transcript_start=cast(values_ptr, felt*) + constants_ptr_len, - poseidon_indexes_ptr=poseidon_indexes_ptr, - n_elements_in_transcript=commitments_len, - n_equations=N_Euclidean_equations, + transcript_start=cast(values_ptr, felt*) + circuit.constants_ptr_len * N_LIMBS + + circuit.input_len, + poseidon_indexes_ptr=circuit.poseidon_indexes_ptr, + n_elements_in_transcript=circuit.commitments_len / N_LIMBS, + n_equations=circuit.N_Euclidean_equations, + init_hash=init_hash, ); - tempvar range_check96_ptr = range_check96_ptr + constants_ptr_len * N_LIMBS + input_len + - commitments_len * N_LIMBS; - write_felts_to_value_segment(values=&Z, n=1); - write_felts_to_value_segment(values=RLC_coeffs, n=N_Euclidean_equations); + %{ + # Sanity Check : + assert ids.Z == EXTF_MOD_CIRCUIT.transcript.continuable_hash, f"Z for circuit {EXTF_MOD_CIRCUIT.name} does not match {hex(ids.Z)} {hex(EXTF_MOD_CIRCUIT.transcript.continuable_hash)}" + %} + + tempvar range_check96_ptr = range_check96_ptr + circuit.constants_ptr_len * N_LIMBS + + circuit.input_len + circuit.commitments_len + circuit.witnesses_len; + write_felts_to_value_segment(values_start=&Z, n=1); + write_felts_to_value_segment(values_start=RLC_coeffs, n=circuit.N_Euclidean_equations); + %{ print(f"\tZ and felt252 written to value segment") %} + %{ print(f"\tRunning ModuloBuiltin circuit...") %} run_mod_p_circuit( p=p, values_ptr=values_ptr, - add_mod_offsets_ptr=add_offsets_ptr, - add_mod_n=add_mod_n, - mul_mod_offsets_ptr=mul_offsets_ptr, - mul_mod_n=mul_mod_n, + add_mod_offsets_ptr=circuit.add_offsets_ptr, + add_mod_n=circuit.add_mod_n, + mul_mod_offsets_ptr=circuit.mul_offsets_ptr, + mul_mod_n=circuit.mul_mod_n, ); + tempvar range_check96_ptr = range_check96_ptr + ( + circuit.add_mod_n + circuit.mul_mod_n - circuit.n_assert_eq + ) * N_LIMBS; - // assert_limbs_at_index_are_equal( - // values_ptr, left_assert_eq_offsets_ptr, right_assert_eq_offsets_ptr, assert_eq_len - // ); - - return 0; + let (output: felt*) = retrieve_output( + values_segment=values_ptr, + output_offsets_ptr=circuit.output_offsets_ptr, + n=circuit.output_len, + continuous_output=circuit.continuous_output, + ); + return (output=output, Z=Z); } diff --git a/src/pairing.cairo b/src/pairing.cairo index 473c8a94..6caea247 100644 --- a/src/pairing.cairo +++ b/src/pairing.cairo @@ -1,15 +1,162 @@ from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.memcpy import memcpy from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin, UInt384 -from starkware.cairo.common.modulo import run_mod_p_circuit -from src.definitions import get_P, E12D, ExtFCircuitInfo, get_final_exp_circuit -from src.utils import ( - get_Z_and_RLC_from_transcript, - write_felts_to_value_segment, - assert_limbs_at_index_are_equal, +from starkware.cairo.common.math import assert_nn_le +from starkware.cairo.common.math_cmp import is_in_range +from src.definitions import E12D, E6D, is_zero_E6D, one_E6D, zero_E12D, one_E12D, G1G2Pair +from starkware.cairo.common.alloc import alloc +from src.precompiled_circuits.final_exp import ( + get_FINAL_EXP_PART_1_circuit, + get_FINAL_EXP_PART_2_circuit, ) +from src.ec_ops import is_on_curve_g1_g2 + +from src.precompiled_circuits.multi_miller_loop import ( + get_MILLER_LOOP_N1_circuit, + get_MILLER_LOOP_N2_circuit, + get_MILLER_LOOP_N3_circuit, + get_COMPUTE_DOUBLE_PAIR_LINES_circuit, + get_ACCUMULATE_SINGLE_PAIR_LINES_circuit, +) +const TRUE = 1; +const FALSE = 0; + +from src.modulo_circuit import ( + run_extension_field_modulo_circuit, + run_extension_field_modulo_circuit_continuation, +) + +func multi_pairing{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}(input: G1G2Pair*, n: felt, curve_id: felt) -> (res: E12D) { + alloc_locals; + let is_n_pair_supported = is_in_range(n, 1, 4); + if (is_n_pair_supported == FALSE) { + let (local res: E12D) = zero_E12D(); + return (res=res); + } + let (all_on_curve) = all_g1_g2_pairs_are_on_curve(input, n, curve_id); + if (all_on_curve == FALSE) { + let (res) = zero_E12D(); + return (res=res); + } + let (m) = multi_miller_loop(cast(input, felt*), n, curve_id); + + let (f) = final_exponentiation(m, curve_id); + + return (res=f); +} + +func all_g1_g2_pairs_are_on_curve{ + range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* +}(input: felt*, n: felt, curve_id: felt) -> (res: felt) { + alloc_locals; + if (n == 0) { + return (res=TRUE); + } else { + let (check) = is_on_curve_g1_g2(curve_id, input); + if (check == TRUE) { + return all_g1_g2_pairs_are_on_curve(input + G1G2Pair.SIZE, n - 1, curve_id); + } else { + return (res=FALSE); + } + } +} func multi_miller_loop{ - range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* -}() -> E12D { + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}(input: felt*, n: felt, curve_id: felt) -> (res: E12D*) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + if (n == 1) { + let (circuit) = get_MILLER_LOOP_N1_circuit(curve_id); + let (output: felt*, _) = run_extension_field_modulo_circuit(circuit, input); + return (res=cast(output, E12D*)); + } + if (n == 2) { + let (circuit) = get_COMPUTE_DOUBLE_PAIR_LINES_circuit(curve_id); + let (output: felt*, Z: felt) = run_extension_field_modulo_circuit(circuit, input); + let (circuit) = get_MILLER_LOOP_N2_circuit(curve_id); + let (output: felt*, _) = run_extension_field_modulo_circuit_continuation( + circuit, output, Z + ); + return (res=cast(output, E12D*)); + } + if (n == 3) { + let (circuit) = get_COMPUTE_DOUBLE_PAIR_LINES_circuit(curve_id); + let (output: felt*, Z: felt) = run_extension_field_modulo_circuit(circuit, input); + let output_end = output + circuit.output_len; + + let (circuit) = get_ACCUMULATE_SINGLE_PAIR_LINES_circuit(curve_id); + memcpy(dst=output_end, src=input + 2 * G1G2Pair.SIZE, len=6 * UInt384.SIZE); + let (output: felt*, Z: felt) = run_extension_field_modulo_circuit_continuation( + circuit, output, Z + ); + + let (circuit) = get_MILLER_LOOP_N3_circuit(curve_id); + let (output: felt*, _) = run_extension_field_modulo_circuit_continuation( + circuit, output, Z + ); + return (res=cast(output, E12D*)); + } + // n >= 3 not implemented. Compose with n=3, n=2, n=1 and fp_12 mul. + let (local res: E12D) = zero_E12D(); + return (res=&res); +} + +func final_exponentiation{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}(input: E12D*, curve_id: felt) -> (res: E12D) { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + + local num: E6D = E6D( + v0=input.w0, v1=input.w2, v2=input.w4, v3=input.w6, v4=input.w8, v5=input.w10 + ); + local den: E6D = E6D( + v0=input.w1, v1=input.w3, v2=input.w5, v3=input.w7, v4=input.w9, v5=input.w11 + ); + let (local circuit_input: felt*) = alloc(); + memcpy(dst=circuit_input, src=cast(&num, felt*), len=24); + + let (den_is_zero) = is_zero_E6D(den, curve_id); + if (den_is_zero == TRUE) { + let (local one_E6: E6D) = one_E6D(); + memcpy(dst=circuit_input + 24, src=cast(&one_E6, felt*), len=24); + } else { + memcpy(dst=circuit_input + 24, src=cast(&den, felt*), len=24); + } + + let (local circuit) = get_FINAL_EXP_PART_1_circuit(curve_id); + let (output: felt*, Z: felt) = run_extension_field_modulo_circuit(circuit, circuit_input); + // %{ + // part1 = pack_bigint_ptr(memory, ids.output, 4, 2**96, ids.circuit.output_len//4) + // for x in part1: + // print(f"T0/T2/_SUM = {hex(x)}") + // %} + let _sum = [cast(output + 2 * E6D.SIZE, E6D*)]; + let (_sum_is_zero) = is_zero_E6D(_sum, curve_id); + + if (_sum_is_zero == TRUE) { + let (one_E12: E12D) = one_E12D(); + return (res=one_E12); + } else { + let (circuit) = get_FINAL_EXP_PART_2_circuit(curve_id); + let (output: felt*, _: felt) = run_extension_field_modulo_circuit_continuation( + circuit, output, Z + ); + return (res=[cast(output, E12D*)]); + } } diff --git a/src/utils.cairo b/src/utils.cairo index be5c2218..86618620 100644 --- a/src/utils.cairo +++ b/src/utils.cairo @@ -8,168 +8,202 @@ func get_Z_and_RLC_from_transcript{poseidon_ptr: PoseidonBuiltin*, range_check96 poseidon_indexes_ptr: felt*, n_elements_in_transcript: felt, n_equations: felt, + init_hash: felt, ) -> (Z: felt, random_linear_combination_coefficients: felt*) { alloc_locals; tempvar poseidon_start = poseidon_ptr; let (Z: felt) = hash_full_transcript_and_get_Z( - limbs_ptr=transcript_start, n=n_elements_in_transcript + limbs_ptr=transcript_start, n=n_elements_in_transcript, init_hash=init_hash ); + let (RLC_coeffs: felt*) = retrieve_random_coefficients( - poseidon_start, poseidon_indexes_ptr=poseidon_indexes_ptr, n=n_equations + poseidon_start=poseidon_start, poseidon_indexes_ptr=poseidon_indexes_ptr, n=n_equations ); + return (Z=Z, random_linear_combination_coefficients=RLC_coeffs); } -func hash_full_transcript_and_get_Z{poseidon_ptr: PoseidonBuiltin*}(limbs_ptr: felt*, n: felt) -> ( - Z: felt -) { +func hash_full_transcript_and_get_Z{poseidon_ptr: PoseidonBuiltin*}( + limbs_ptr: felt*, n: felt, init_hash: felt +) -> (Z: felt) { alloc_locals; - %{ print(f"N elemts in transcript : {ids.n} ") %} - local two = 2; - let input_hash = 14; - + // %{ print(f"N elemts in transcript : {ids.n} ") %} + local ptr: felt* = cast(poseidon_ptr, felt*); + // %{ + // from src.hints.io import pack_bigint_ptr + // to_hash=pack_bigint_ptr(memory, ids.limbs_ptr, ids.N_LIMBS, ids.BASE, ids.n) + // for e in to_hash: + // print(f"Will Hash {hex(e)}") + // %} // Initialisation: - assert poseidon_ptr[0].input = PoseidonBuiltinState( - limbs_ptr[0] * limbs_ptr[1], input_hash, two - ); - assert poseidon_ptr[1].input = PoseidonBuiltinState( - limbs_ptr[2] * limbs_ptr[3], poseidon_ptr[0].output.s0, two - ); + %{ + for i in range(2*ids.n -1): + memory[ids.ptr + 2 + 6*i] = 2 + memory[ids.ptr + 8 + 6*i] = 2 + %} + assert ptr[0] = limbs_ptr[0] * limbs_ptr[1]; + assert ptr[1] = init_hash; + assert ptr[6] = limbs_ptr[2] * limbs_ptr[3]; + assert ptr[7] = ptr[3]; + + %{ i=0 %} tempvar limbs_ptr: felt* = limbs_ptr + 4; - tempvar i = 2; + tempvar pos_ptr: felt* = ptr + 2 * PoseidonBuiltin.SIZE; hash_limbs_2_by_2: let limbs_ptr: felt* = cast([ap - 2], felt*); - let i = [ap - 1]; + let pos_ptr: felt* = cast([ap - 1], felt*); %{ - print(ids.i/2, "/", ids.n) - memory[ap] = 1 if ids.i == 2*ids.n else 0 + i+=1 + memory[ap] = 1 if i == ids.n else 0 %} jmp end_loop if [ap] != 0, ap++; - assert poseidon_ptr[i].input = PoseidonBuiltinState( - limbs_ptr[0] * limbs_ptr[1], poseidon_ptr[i - 1].output.s0, two - ); - assert poseidon_ptr[i + 1].input = PoseidonBuiltinState( - limbs_ptr[2] * limbs_ptr[3], poseidon_ptr[i].output.s0, two - ); + assert [pos_ptr] = limbs_ptr[0] * limbs_ptr[1]; + assert [pos_ptr + 1] = [pos_ptr - 3]; + assert [pos_ptr + 6] = limbs_ptr[2] * limbs_ptr[3]; + assert [pos_ptr + 7] = [pos_ptr + 3]; [ap] = limbs_ptr + 4, ap++; - [ap] = i + 2, ap++; + [ap] = pos_ptr + 2 * PoseidonBuiltin.SIZE, ap++; jmp hash_limbs_2_by_2; end_loop: - // let i = [ap - 1]; - assert i = 2 * n; - %{ print(f"i: {ids.i}, n:{ids.n}") %} - tempvar poseidon_ptr = poseidon_ptr + PoseidonBuiltin.SIZE * n * 2; + assert 2 * n * PoseidonBuiltin.SIZE = cast(pos_ptr, felt) - cast(ptr, felt); + tempvar poseidon_ptr = cast(pos_ptr, PoseidonBuiltin*); tempvar res = [poseidon_ptr - PoseidonBuiltin.SIZE].output.s0; + return (Z=res); } func retrieve_random_coefficients( - poseidon_ptr: PoseidonBuiltin*, poseidon_indexes_ptr: felt*, n: felt + poseidon_start: PoseidonBuiltin*, poseidon_indexes_ptr: felt*, n: felt ) -> (coefficients: felt*) { alloc_locals; - let (local coefficients: felt*) = alloc(); + let (local coefficients_start: felt*) = alloc(); + local ptr: felt* = cast(poseidon_start, felt*); - tempvar i = 0; + %{ i=0 %} + assert [coefficients_start] = [ptr + [poseidon_indexes_ptr]]; + tempvar coefficients = coefficients_start + 1; + tempvar poseidon_indexes_ptr = poseidon_indexes_ptr + 1; get_s1_loop: - let i = [ap - 1]; - %{ memory[ap] = 1 if ids.i == ids.n else 0 %} + let coefficients = cast([ap - 2], felt*); + let poseidon_indexes_ptr = cast([ap - 1], felt*); + %{ + i+=1 + memory[ap] = 1 if i == ids.n else 0 + %} jmp end if [ap] != 0, ap++; - assert coefficients[i] = poseidon_ptr[poseidon_indexes_ptr[i]].output.s1; - [ap] = i + 1, ap++; + assert [coefficients] = [ptr + [poseidon_indexes_ptr]]; + [ap] = coefficients + 1, ap++; + [ap] = poseidon_indexes_ptr + 1, ap++; jmp get_s1_loop; end: - assert i = n; - return (coefficients=coefficients); + assert n = cast(coefficients, felt) - cast(coefficients_start, felt); + // %{ + // from src.hints.io import pack_bigint_ptr + // array=pack_bigint_ptr(memory, ids.coefficients, 1, ids.BASE, ids.n) + // for i,e in enumerate(array): + // print(f"CAIRO Using c_{i} = {hex(e)}") + // %} + return (coefficients=coefficients_start); } -func write_felts_to_value_segment{range_check96_ptr: felt*}(values: felt*, n: felt) -> () { +func write_felts_to_value_segment{range_check96_ptr: felt*}(values_start: felt*, n: felt) -> () { alloc_locals; local stark_min_1_d2 = STARK_MIN_ONE_D2; local n_rc_per_felt = N_LIMBS + 1; - tempvar i = 0; + %{ i=0 %} + tempvar values = values_start; + tempvar rc_96_ptr = range_check96_ptr; loop: - let i = [ap - 1]; - %{ memory[ap] = 1 if ids.i == ids.n else 0 %} + let values = cast([ap - 2], felt*); + let rc_96_ptr = cast([ap - 1], felt*); + %{ + memory[ap] = 1 if i == ids.n else 0 + i+=1 + %} jmp end if [ap] != 0, ap++; - tempvar offset = i * n_rc_per_felt; - - let d0 = [range_check96_ptr + offset]; - let d1 = [range_check96_ptr + offset + 1]; - let d2 = [range_check96_ptr + offset + 2]; + let d0 = [rc_96_ptr]; + let d1 = [rc_96_ptr + 1]; + let d2 = [rc_96_ptr + 2]; %{ from src.hints.io import bigint_split - felt_val = memory[ids.values+ids.i] - print(f"felt val : {felt_val}") + felt_val = memory[ids.values_start+i-1] limbs = bigint_split(felt_val, ids.N_LIMBS, ids.BASE) assert limbs[3] == 0 ids.d0, ids.d1, ids.d2 = limbs[0], limbs[1], limbs[2] %} - assert [range_check96_ptr + offset + 3] = 0; - assert [range_check96_ptr + offset + 4] = stark_min_1_d2 - d2; - assert 0 = values[i] - (d0 + d1 * 2 ** 96 + d2 * (2 ** 96) ** 2); + assert [rc_96_ptr + 3] = 0; + assert [rc_96_ptr + 4] = stark_min_1_d2 - d2; + assert [values] = (d0 + d1 * BASE + d2 * BASE ** 2); if (d2 == stark_min_1_d2) { + // Take advantage of Cairo prime structure. STARK_MIN_ONE = 0 + 0 * BASE + stark_min_1_d2 * (BASE)**2. assert d1 = 0; assert d2 = 0; - [ap] = i + 1, ap++; + [ap] = values + 1, ap++; + [ap] = rc_96_ptr + n_rc_per_felt, ap++; } else { - [ap] = i + 1, ap++; + [ap] = values + 1, ap++; + [ap] = rc_96_ptr + n_rc_per_felt, ap++; } jmp loop; end: - assert i = n; - %{ print(f"RangeCheckptr:{ids.range_check96_ptr}", ids.n, ids.n_rc_per_felt) %} - tempvar range_check96_ptr = range_check96_ptr + n * n_rc_per_felt; + assert n = cast(values, felt) - cast(values_start, felt); + tempvar range_check96_ptr = rc_96_ptr; return (); } -func assert_limbs_at_index_are_equal{}( - values_segment: felt*, left_assert_eq_offsets: felt*, right_assert_eq_offsets: felt*, n: felt -) -> () { +func retrieve_output{}( + values_segment: felt*, output_offsets_ptr: felt*, n: felt, continuous_output: felt +) -> (output: felt*) { + if (continuous_output != 0) { + let offset = output_offsets_ptr[0]; + // %{ print(f"Continuous output! start value : {hex(memory[ids.values_segment + ids.offset])} Size: {ids.n//4} offset:{ids.offset}") %} + return (cast(values_segment + offset, felt*),); + } alloc_locals; + let (local output: felt*) = alloc(); local one = 1; local two = 2; local three = 3; tempvar i = 0; + tempvar output_offsets = output_offsets_ptr; loop: - let i = [ap - 1]; - %{ memory[ap] = 1 if ids.i == ids.n else 0 %} + let i = [ap - 2]; + let output_offsets = cast([ap - 1], felt*); + %{ + index = memory[ids.output_offsets_ptr+ids.i] + # print(f"Output {ids.i}/{ids.n} Index : {index}") + memory[ap] = 1 if ids.i == ids.n else 0 + %} jmp end if [ap] != 0, ap++; tempvar i_plus_one = i + one; tempvar i_plus_two = i + two; tempvar i_plus_three = i + three; - assert values_segment[left_assert_eq_offsets[i]] - values_segment[ - right_assert_eq_offsets[i] - ] = 0; - assert values_segment[left_assert_eq_offsets[i_plus_one]] - values_segment[ - right_assert_eq_offsets[i_plus_one] - ] = 0; - assert values_segment[left_assert_eq_offsets[i_plus_two]] - values_segment[ - right_assert_eq_offsets[i_plus_two] - ] = 0; - assert values_segment[left_assert_eq_offsets[i_plus_three]] - values_segment[ - right_assert_eq_offsets[i_plus_three] - ] = 0; - - [ap] = i + 1, ap++; + assert output[i] = values_segment[[output_offsets]]; + assert output[i_plus_one] = values_segment[[output_offsets] + one]; + assert output[i_plus_two] = values_segment[[output_offsets] + two]; + assert output[i_plus_three] = values_segment[[output_offsets] + three]; + + [ap] = i + 4, ap++; + [ap] = output_offsets + 1, ap++; jmp loop; end: assert i = n; - return (); + return (output=output); } diff --git a/tests/cairo_programs/extf_circuit.cairo b/tests/cairo_programs/extf_circuit.cairo index 49afdd1c..a2a22e0a 100644 --- a/tests/cairo_programs/extf_circuit.cairo +++ b/tests/cairo_programs/extf_circuit.cairo @@ -4,9 +4,16 @@ from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.alloc import alloc -from src.modulo_circuit import run_extension_field_modulo_circuit -from src.definitions import bn, bls, UInt384, one_E12D, N_LIMBS, BASE -from src.precompiled_circuits.sample import get_sample_circuit +from src.modulo_circuit import ( + run_extension_field_modulo_circuit, + run_extension_field_modulo_circuit_continuation, +) +from src.definitions import bn, bls, UInt384, N_LIMBS, BASE, E12D + +from src.precompiled_circuits.extf_mul import get_FP12_MUL_circuit + +from src.modulo_circuit import ExtensionFieldModuloCircuit + func main{ range_check_ptr, poseidon_ptr: PoseidonBuiltin*, @@ -16,23 +23,50 @@ func main{ }() { alloc_locals; let (__fp__, _) = get_fp_and_pc(); - let (local input: felt*) = alloc(); - local input_len: felt; - local circuit_id = 1; + let (local input_bn: felt*) = alloc(); + let (local input_bls: felt*) = alloc(); + + local expected_bn: E12D; + local expected_bls: E12D; + %{ from random import randint import random from src.definitions import CURVES, PyFelt - from src.hints.io import bigint_split, flatten + from src.hints.io import bigint_split, flatten, pack_e12d, fill_e12d random.seed(0) - p = CURVES[ids.bn.CURVE_ID].p - X=[PyFelt(randint(0, p - 1), p) for _ in range(6)] - X=flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in X]) - print(X, len(X)) - segments.write_arg(ids.input, X) - ids.input_len = len(X) + + def generate_input_for_fp_mul(ptr:object, curve_id: int, extension_degree:int) -> list: + p = CURVES[curve_id].p + X = [PyFelt(randint(0, p - 1), p) for _ in range(2*extension_degree)] + X = flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in X]) + segments.write_arg(ptr, X) + return X + + generate_input_for_fp_mul(ids.input_bn, ids.bn.CURVE_ID, 12) + generate_input_for_fp_mul(ids.input_bls, ids.bls.CURVE_ID, 12) + #fill_e12d(ids.expected_bn, 4, 2**96) + #fill_e12d(ids.expected_bls, 4, 2**96) %} - let x = run_extension_field_modulo_circuit(input, input_len, bn.CURVE_ID, circuit_id); + let (circuit) = get_FP12_MUL_circuit(bn.CURVE_ID); + let (output, _) = run_extension_field_modulo_circuit(circuit, input_bn); + local res_bn: E12D = [cast(output, E12D*)]; + + // let (circuit) = get_FP12_MUL_circuit(bls.CURVE_ID); + // let (output, _) = run_extension_field_modulo_circuit(circuit, input_bls); + // local res_bls: E12D = [cast(output, E12D*)]; + // %{ + // res_bn = pack_e12d(ids.res_bn, 4, 2**96) + // res_bls = pack_e12d(ids.res_bls, 4, 2**96) + // print(f"res_bn: {res_bn}") + // print(f"expected_bn: {expected_outputs[0]}") + // print(f"res_bls: {res_bls}") + // print(f"expected_bls: {expected_outputs[1]}") + // assert res_bn == expected_outputs[0] + // assert res_bls == expected_outputs[1] + // print(f"Test passed") + // %} + return (); } diff --git a/tests/cairo_programs/modulo.cairo b/tests/cairo_programs/modulo.cairo index af5668fd..e21c561e 100644 --- a/tests/cairo_programs/modulo.cairo +++ b/tests/cairo_programs/modulo.cairo @@ -117,12 +117,7 @@ func apply_poly{ } func main{ - range_check_ptr, - bitwise_ptr: BitwiseBuiltin*, - poseidon_ptr: PoseidonBuiltin*, - range_check96_ptr: felt*, - add_mod_ptr: ModBuiltin*, - mul_mod_ptr: ModBuiltin*, + range_check_ptr, range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin* }() { alloc_locals; diff --git a/tests/cairo_programs/test_final_exp.cairo b/tests/cairo_programs/test_final_exp.cairo new file mode 100644 index 00000000..7f261328 --- /dev/null +++ b/tests/cairo_programs/test_final_exp.cairo @@ -0,0 +1,74 @@ +%builtins range_check poseidon range_check96 add_mod mul_mod + +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.alloc import alloc + +from src.definitions import bn, bls, UInt384, one_E12D, N_LIMBS, BASE, E12D + +from src.pairing import final_exponentiation +from src.modulo_circuit import ExtensionFieldModuloCircuit + +func main{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}() { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let (local input_bn: felt*) = alloc(); + let (local input_bls: felt*) = alloc(); + %{ + from random import randint + import random + from tools.gnark_cli import GnarkCLI + from src.definitions import CURVES, PyFelt, CurveID, get_base_field, tower_to_direct + from src.hints.io import bigint_split, flatten, pack_e12d + random.seed(0) + + clis = [GnarkCLI(CurveID(ids.bn.CURVE_ID)), GnarkCLI(CurveID(ids.bls.CURVE_ID))] + inputs = [] + expected_outputs = [] + for cli in clis: + order = CURVES[cli.curve_id.value].n + field = get_base_field(cli.curve_id.value) + pairs = [] + n_pairs = 1 + for _ in range(n_pairs): + n1, n2 = randint(1, order), randint(1, order) + pairs.extend(cli.nG1nG2_operation(n1, n2, raw=True)) + + XT = cli.miller(input=pairs, n_pairs=1) + ET = cli.pair(input=pairs, n_pairs=1) + XT = [field(x) for x in XT] + ET = [field(x) for x in ET] + XD = tower_to_direct(XT, cli.curve_id.value, 12) + ED = tower_to_direct(ET, cli.curve_id.value, 12) + inputs.append(XD) + expected_outputs.append([x.value for x in ED]) + + + + segments.write_arg(ids.input_bn, flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in inputs[0]])) + segments.write_arg(ids.input_bls, flatten([bigint_split(x.value, ids.N_LIMBS, ids.BASE) for x in inputs[1]])) + %} + + let (local res_bn: E12D) = final_exponentiation(cast(input_bn, E12D*), bn.CURVE_ID); + + let (local res_bls: E12D) = final_exponentiation(cast(input_bls, E12D*), bls.CURVE_ID); + + %{ + res_bn = pack_e12d(ids.res_bn, 4, 2**96) + res_bls = pack_e12d(ids.res_bls, 4, 2**96) + assert res_bn == expected_outputs[0] + assert res_bls == expected_outputs[1] + #print(f"res_bn: {res_bn}") + #print(f"expected_bn: {expected_outputs[0]}\n") + #print(f"res_bls: {res_bls}") + #print(f"expected_bls: {expected_outputs[1]}") + print(f"Test Passed\n") + %} + return (); +} diff --git a/tests/cairo_programs/test_pairing.cairo b/tests/cairo_programs/test_pairing.cairo new file mode 100644 index 00000000..2c27c8f4 --- /dev/null +++ b/tests/cairo_programs/test_pairing.cairo @@ -0,0 +1,66 @@ +%builtins range_check poseidon range_check96 add_mod mul_mod + +from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, ModBuiltin +from starkware.cairo.common.registers import get_fp_and_pc +from starkware.cairo.common.alloc import alloc + +from src.definitions import bn, bls, UInt384, one_E12D, N_LIMBS, BASE, E12D + +from src.pairing import multi_pairing, G1G2Pair +from src.modulo_circuit import ExtensionFieldModuloCircuit + +func main{ + range_check_ptr, + poseidon_ptr: PoseidonBuiltin*, + range_check96_ptr: felt*, + add_mod_ptr: ModBuiltin*, + mul_mod_ptr: ModBuiltin*, +}() { + alloc_locals; + let (__fp__, _) = get_fp_and_pc(); + let (local inputs: felt*) = alloc(); + + local n_pairs: felt; + local curve_id: felt; + %{ + from tools.gnark_cli import GnarkCLI + from src.definitions import CURVES, PyFelt, CurveID, get_base_field, tower_to_direct + from src.hints.io import bigint_split, flatten, pack_e12d + + ids.n_pairs = program_input['n_pairs'] + ids.curve_id=program_input['curve_id'] + n1s, n2s = program_input['n1s'], program_input['n2s'] + + def prepare_inputs_and_expected_outputs(cli, n_pairs): + order = CURVES[cli.curve_id.value].n + field = get_base_field(cli.curve_id.value) + pairs = [] + for i in range(n_pairs): + n1, n2 = n1s[i], n2s[i] + pairs.extend(cli.nG1nG2_operation(n1, n2, raw=True)) + + inputs = flatten([bigint_split(x, ids.N_LIMBS, ids.BASE) for x in pairs]) + ET = cli.pair(input=pairs, n_pairs=n_pairs) + ET = [field(x) for x in ET] + ED = tower_to_direct(ET, cli.curve_id.value, 12) + + expected_outputs=[x.value for x in ED] + return inputs, expected_outputs + + cli = GnarkCLI(CurveID(ids.curve_id)) + inputs, expected_outputs = prepare_inputs_and_expected_outputs(cli, ids.n_pairs) + + segments.write_arg(ids.inputs, inputs) + %} + + let (local res: E12D) = multi_pairing(cast(inputs, G1G2Pair*), n_pairs, curve_id); + %{ + res = pack_e12d(ids.res, 4, 2**96) + #print(f"res: {[hex(x) for x in res]}") + #print(f"expected: {[hex(x) for x in expected_outputs]}\n") + assert res == expected_outputs, f"res: {res}, expected: {expected_outputs}" + %} + + %{ print(f"Test Passed\n") %} + return (); +}