Skip to content

Commit

Permalink
PySumCheck test & fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Nov 5, 2024
1 parent 776d7be commit ca65bba
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 86 deletions.
13 changes: 10 additions & 3 deletions hydra/garaga/modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,19 @@ def is_empty_circuit(self) -> bool:

def write_element(
self,
elmt: PyFelt,
elmt: PyFelt | int,
write_source: WriteOps = WriteOps.INPUT,
instruction: ModuloCircuitInstruction | None = None,
) -> ModuloCircuitElement:
"""
Register an emulated field element to the circuit given its value and the write source.
Returns a ModuloCircuitElement representing the written element with its offset as identifier.
"""
assert isinstance(elmt, PyFelt), f"Expected PyFelt, got {type(elmt)}"
assert isinstance(elmt, PyFelt) or isinstance(
elmt, int
), f"Expected PyFelt or int, got {type(elmt)}"
if isinstance(elmt, int):
elmt = self.field(elmt)
value_offset = self.values_segment.write_to_segment(
ValueSegmentItem(
elmt,
Expand Down Expand Up @@ -432,7 +436,10 @@ def write_struct(
return result

def write_elements(
self, elmts: list[PyFelt], operation: WriteOps, sparsity: list[int] = None
self,
elmts: list[PyFelt],
operation: WriteOps = WriteOps.INPUT,
sparsity: list[int] = None,
) -> list[ModuloCircuitElement]:
if sparsity is not None:
assert len(sparsity) == len(
Expand Down
24 changes: 18 additions & 6 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from pathlib import Path

from garaga.definitions import CurveID
from garaga.precompiled_circuits.compilable_circuits.base import (
Expand Down Expand Up @@ -61,6 +62,8 @@
)
from garaga.starknet.cli.utils import create_directory

STARKNET_DIR = Path(__file__).parent.parent / "starknet"


class CircuitID(Enum):
DUMMY = int.from_bytes(b"dummy", "big")
Expand Down Expand Up @@ -285,12 +288,6 @@ class CircuitID(Enum):
"filename": "isogeny",
"curve_ids": [CurveID.BLS12_381],
},
CircuitID.HONK_SUMCHECK_CIRCUIT: {
"class": SumCheckCircuit,
"params": [{"vk": HonkVk.mock()}],
"filename": "honk_circuits",
"curve_ids": [CurveID.GRUMPKIN],
},
CircuitID.TOWER_MILLER_BIT0: {
"class": TowerMillerBit0,
"params": [{"n_pairs": k} for k in [1]],
Expand Down Expand Up @@ -381,6 +378,21 @@ class CircuitID(Enum):
"filename": "tower_circuits",
"curve_ids": [CurveID.BLS12_381],
},
CircuitID.HONK_SUMCHECK_CIRCUIT: {
"class": SumCheckCircuit,
"params": [
{
"vk": HonkVk.from_bytes(
open(
f"{STARKNET_DIR}/honk_contract_generator/examples/vk_ultra_keccak.bin",
"rb",
).read()
)
}
],
"filename": "honk_circuits",
"curve_ids": [CurveID.GRUMPKIN],
},
}


Expand Down
178 changes: 101 additions & 77 deletions hydra/garaga/precompiled_circuits/honk_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sha3

import garaga.modulo_circuit_structs as structs
from garaga.definitions import CURVES, CurveID, G1Point, G2Point
from garaga.extension_field_modulo_circuit import ModuloCircuit, ModuloCircuitElement

Expand Down Expand Up @@ -66,31 +67,6 @@ def __post_init__(self):
assert len(self.gemini_fold_comms) == CONST_PROOF_SIZE_LOG_N - 1
assert len(self.gemini_a_evaluations) == CONST_PROOF_SIZE_LOG_N

@classmethod
def mock(cls) -> "HonkProof":
return cls(
circuit_size=0,
public_inputs_size=0,
public_inputs_offset=0,
w1=G1Point.get_nG(CurveID.GRUMPKIN, 1),
w2=G1Point.get_nG(CurveID.GRUMPKIN, 2),
w3=G1Point.get_nG(CurveID.GRUMPKIN, 3),
w4=G1Point.get_nG(CurveID.GRUMPKIN, 4),
z_perm=G1Point.get_nG(CurveID.GRUMPKIN, 5),
lookup_read_counts=G1Point.get_nG(CurveID.GRUMPKIN, 6),
lookup_read_tags=G1Point.get_nG(CurveID.GRUMPKIN, 7),
lookup_inverses=G1Point.get_nG(CurveID.GRUMPKIN, 8),
sumcheck_univariates=[
[i] * CONST_PROOF_SIZE_LOG_N
for i in range(BATCHED_RELATION_PARTIAL_LENGTH)
],
sumcheck_evaluations=[0] * NUMBER_OF_ENTITIES,
gemini_fold_comms=[0] * (CONST_PROOF_SIZE_LOG_N - 1),
gemini_a_evaluations=[0] * CONST_PROOF_SIZE_LOG_N,
shplonk_q=G1Point.get_nG(CurveID.GRUMPKIN, 9),
kzg_quotient=G1Point.get_nG(CurveID.GRUMPKIN, 10),
)

@classmethod
def from_bytes(cls, bytes: bytes) -> "HonkProof":
n_elements = int.from_bytes(bytes[:4], "big")
Expand Down Expand Up @@ -191,6 +167,55 @@ def parse_g1_proof_point(i: int) -> G1Point:
kzg_quotient=kzg_quotient,
)

def to_circuit_elements(self, circuit: ModuloCircuit) -> "HonkProof":
"""Convert everything to ModuloCircuitElements given a circuit."""
return HonkProof(
circuit_size=self.circuit_size,
public_inputs_size=self.public_inputs_size,
public_inputs_offset=circuit.write_element(self.public_inputs_offset),
public_inputs=circuit.write_elements(self.public_inputs),
w1=circuit.write_struct(structs.G1PointCircuit.from_G1Point("w1", self.w1)),
w2=circuit.write_struct(structs.G1PointCircuit.from_G1Point("w2", self.w2)),
w3=circuit.write_struct(structs.G1PointCircuit.from_G1Point("w3", self.w3)),
w4=circuit.write_struct(structs.G1PointCircuit.from_G1Point("w4", self.w4)),
z_perm=circuit.write_struct(
structs.G1PointCircuit.from_G1Point("z_perm", self.z_perm)
),
lookup_read_counts=circuit.write_struct(
structs.G1PointCircuit.from_G1Point(
"lookup_read_counts", self.lookup_read_counts
)
),
lookup_read_tags=circuit.write_struct(
structs.G1PointCircuit.from_G1Point(
"lookup_read_tags", self.lookup_read_tags
)
),
lookup_inverses=circuit.write_struct(
structs.G1PointCircuit.from_G1Point(
"lookup_inverses", self.lookup_inverses
)
),
sumcheck_univariates=[
circuit.write_elements(univariate)
for univariate in self.sumcheck_univariates
],
sumcheck_evaluations=circuit.write_elements(self.sumcheck_evaluations),
gemini_fold_comms=[
circuit.write_struct(
structs.G1PointCircuit.from_G1Point(f"gemini_fold_comm_{i}", comm)
)
for i, comm in enumerate(self.gemini_fold_comms)
],
gemini_a_evaluations=circuit.write_elements(self.gemini_a_evaluations),
shplonk_q=circuit.write_struct(
structs.G1PointCircuit.from_G1Point("shplonk_q", self.shplonk_q)
),
kzg_quotient=circuit.write_struct(
structs.G1PointCircuit.from_G1Point("kzg_quotient", self.kzg_quotient)
),
)


@dataclass
class HonkVk:
Expand Down Expand Up @@ -237,43 +262,6 @@ def __repr__(self) -> str:
# def __str__(self) -> str:
# return self.__repr__()

@classmethod
def mock(cls, log_circuit_size: int = 16, public_inputs_size: int = 6) -> "HonkVk":
return cls(
name="mock",
circuit_size=2**log_circuit_size,
log_circuit_size=log_circuit_size,
public_inputs_size=public_inputs_size,
public_inputs_offset=1,
qm=G1Point.get_nG(CurveID.GRUMPKIN, 1),
qc=G1Point.get_nG(CurveID.GRUMPKIN, 2),
ql=G1Point.get_nG(CurveID.GRUMPKIN, 3),
qr=G1Point.get_nG(CurveID.GRUMPKIN, 4),
qo=G1Point.get_nG(CurveID.GRUMPKIN, 5),
q4=G1Point.get_nG(CurveID.GRUMPKIN, 6),
qArith=G1Point.get_nG(CurveID.GRUMPKIN, 7),
qDeltaRange=G1Point.get_nG(CurveID.GRUMPKIN, 8),
qAux=G1Point.get_nG(CurveID.GRUMPKIN, 9),
qElliptic=G1Point.get_nG(CurveID.GRUMPKIN, 10),
qLookup=G1Point.get_nG(CurveID.GRUMPKIN, 11),
qPoseidon2External=G1Point.get_nG(CurveID.GRUMPKIN, 12),
qPoseidon2Internal=G1Point.get_nG(CurveID.GRUMPKIN, 13),
s1=G1Point.get_nG(CurveID.GRUMPKIN, 14),
s2=G1Point.get_nG(CurveID.GRUMPKIN, 15),
s3=G1Point.get_nG(CurveID.GRUMPKIN, 16),
s4=G1Point.get_nG(CurveID.GRUMPKIN, 17),
id1=G1Point.get_nG(CurveID.GRUMPKIN, 18),
id2=G1Point.get_nG(CurveID.GRUMPKIN, 19),
id3=G1Point.get_nG(CurveID.GRUMPKIN, 20),
id4=G1Point.get_nG(CurveID.GRUMPKIN, 21),
t1=G1Point.get_nG(CurveID.GRUMPKIN, 22),
t2=G1Point.get_nG(CurveID.GRUMPKIN, 23),
t3=G1Point.get_nG(CurveID.GRUMPKIN, 24),
t4=G1Point.get_nG(CurveID.GRUMPKIN, 25),
lagrange_first=G1Point.get_nG(CurveID.GRUMPKIN, 26),
lagrange_last=G1Point.get_nG(CurveID.GRUMPKIN, 27),
)

@classmethod
def from_bytes(cls, bytes: bytes) -> "HonkVk":
circuit_size = int.from_bytes(bytes[0:8], "big")
Expand Down Expand Up @@ -336,6 +324,24 @@ def serialize_to_cairo(self, name: str = "vk") -> str:
code += "};"
return code

def to_circuit_elements(self, circuit: ModuloCircuit) -> "HonkVk":
return HonkVk(
name=self.name,
circuit_size=self.circuit_size,
log_circuit_size=self.log_circuit_size,
public_inputs_size=self.public_inputs_size,
public_inputs_offset=circuit.write_element(self.public_inputs_offset),
**{
field.name: circuit.write_struct(
structs.G1PointCircuit.from_G1Point(
field.name, getattr(self, field.name)
)
)
for field in fields(self)
if field.type == G1Point and field.name != "name"
},
)


class Sha3Transcript:
def __init__(self):
Expand All @@ -356,19 +362,19 @@ def update(self, data: bytes):

@dataclass
class HonkTranscript:
eta: ModuloCircuitElement
etaTwo: ModuloCircuitElement
etaThree: ModuloCircuitElement
beta: ModuloCircuitElement
gamma: ModuloCircuitElement
alphas: list[ModuloCircuitElement]
gate_challenges: list[ModuloCircuitElement]
eta: int | ModuloCircuitElement
etaTwo: int | ModuloCircuitElement
etaThree: int | ModuloCircuitElement
beta: int | ModuloCircuitElement
gamma: int | ModuloCircuitElement
alphas: list[int | ModuloCircuitElement]
gate_challenges: list[int | ModuloCircuitElement]
sum_check_u_challenges: list[ModuloCircuitElement]
rho: ModuloCircuitElement
gemini_r: ModuloCircuitElement
shplonk_nu: ModuloCircuitElement
shplonk_z: ModuloCircuitElement
public_inputs_delta: ModuloCircuitElement | None = None # Derived.
rho: int | ModuloCircuitElement
gemini_r: int | ModuloCircuitElement
shplonk_nu: int | ModuloCircuitElement
shplonk_z: int | ModuloCircuitElement
public_inputs_delta: int | None = None # Derived.

def __post_init__(self):
assert len(self.alphas) == NUMBER_OF_ALPHAS
Expand Down Expand Up @@ -571,6 +577,23 @@ def split_challenge(ch: bytes) -> tuple[int, int]:
public_inputs_delta=None,
)

def to_circuit_elements(self, circuit: ModuloCircuit) -> "HonkTranscript":
return HonkTranscript(
eta=circuit.write_element(self.eta),
etaTwo=circuit.write_element(self.etaTwo),
etaThree=circuit.write_element(self.etaThree),
beta=circuit.write_element(self.beta),
gamma=circuit.write_element(self.gamma),
alphas=circuit.write_elements(self.alphas),
gate_challenges=circuit.write_elements(self.gate_challenges),
sum_check_u_challenges=circuit.write_elements(self.sum_check_u_challenges),
rho=circuit.write_element(self.rho),
gemini_r=circuit.write_element(self.gemini_r),
shplonk_nu=circuit.write_element(self.shplonk_nu),
shplonk_z=circuit.write_element(self.shplonk_z),
public_inputs_delta=None,
)


class HonkVerifierCircuits(ModuloCircuit):
def __init__(
Expand Down Expand Up @@ -637,6 +660,7 @@ def compute_public_input_delta(
# - domain_size : part of vk.
# - offset : proof pub input offset
"""
assert len(public_inputs) > 0
num = self.set_or_get_constant(1)
den = self.set_or_get_constant(1)

Expand Down Expand Up @@ -1060,9 +1084,9 @@ def accumulate_delta_range_relation(
evaluations[6 + i] = self.product(
[
delta,
self.add(delta_1, self.set_or_get_constant(-1)),
self.add(delta_1, self.set_or_get_constant(-2)),
self.add(delta_1, self.set_or_get_constant(-3)),
self.add(delta, self.set_or_get_constant(-1)),
self.add(delta, self.set_or_get_constant(-2)),
self.add(delta, self.set_or_get_constant(-3)),
p[Wire.Q_RANGE],
domain_separator,
]
Expand Down
47 changes: 47 additions & 0 deletions tests/hydra/circuits/test_honk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from hydra.garaga.precompiled_circuits.honk_new import (
HonkProof,
HonkTranscript,
HonkVerifierCircuits,
HonkVk,
)

PATH = "hydra/garaga/starknet/honk_contract_generator/examples"


def test_sumcheck_circuit():
vk = HonkVk.from_bytes(open(f"{PATH}/vk_ultra_keccak.bin", "rb").read())
proof = HonkProof.from_bytes(open(f"{PATH}/proof_ultra_keccak.bin", "rb").read())
tp = HonkTranscript.from_proof(proof)

circuit = HonkVerifierCircuits(name="test", log_n=vk.log_circuit_size)

vk = vk.to_circuit_elements(circuit)
proof = proof.to_circuit_elements(circuit)
tp = tp.to_circuit_elements(circuit)

public_input_delta = circuit.compute_public_input_delta(
public_inputs=proof.public_inputs,
beta=tp.beta,
gamma=tp.gamma,
domain_size=vk.circuit_size,
offset=vk.public_inputs_offset,
)

rlc_check, check = circuit.verify_sum_check(
sumcheck_univariates=proof.sumcheck_univariates,
sumcheck_evaluations=proof.sumcheck_evaluations,
beta=tp.beta,
gamma=tp.gamma,
public_inputs_delta=public_input_delta,
eta=tp.eta,
eta_two=tp.etaTwo,
eta_three=tp.etaThree,
sum_check_u_challenges=tp.sum_check_u_challenges,
gate_challenges=tp.gate_challenges,
alphas=tp.alphas,
log_n=vk.log_circuit_size,
base_rlc=circuit.write_element(1234),
)

assert rlc_check.value == 0
assert check.value == 0

0 comments on commit ca65bba

Please sign in to comment.