Skip to content

Commit

Permalink
bls 3 pairs test passing + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Jul 21, 2024
1 parent d148171 commit eaf3e8f
Show file tree
Hide file tree
Showing 14 changed files with 6,386 additions and 529 deletions.
26 changes: 26 additions & 0 deletions hydra/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,32 @@ def gen_random_point(curve_id: CurveID) -> "G2Point":
)


@dataclass(slots=True)
class G1G2Pair:
p: G1Point
q: G2Point
curve_id: CurveID = None

def __post_init__(self):
if self.p.curve_id != self.q.curve_id:
raise ValueError("Points are not on the same curve")
self.curve_id = self.p.curve_id

def to_pyfelt_list(self) -> list[PyFelt]:
field = get_base_field(self.curve_id.value)
return [
field(x)
for x in [
self.p.x,
self.p.y,
self.q.x[0],
self.q.x[1],
self.q.y[0],
self.q.y[1],
]
]


# v^6 - 18v^3 + 82
# w^12 - 18w^6 + 82
# v^6 - 2v^3 + 2
Expand Down
18 changes: 13 additions & 5 deletions hydra/extension_field_modulo_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,19 @@ def __init__(

def _init_accumulator(self, extension_degree: int = None):
extension_degree = extension_degree or self.extension_degree
return EuclideanPolyAccumulator(
lhs=self.set_or_get_constant(0),
R=[self.set_or_get_constant(0)] * extension_degree,
R_evaluated=self.set_or_get_constant(0),
)
# Todo : Add compilation mode 1 support
if self.compilation_mode == 1:
return EuclideanPolyAccumulator(
lhs=None,
R=[None] * extension_degree,
R_evaluated=None,
)
else:
return EuclideanPolyAccumulator(
lhs=self.set_or_get_constant(0),
R=[self.set_or_get_constant(0)] * extension_degree,
R_evaluated=self.set_or_get_constant(0),
)

@property
def commitments(self):
Expand Down
100 changes: 77 additions & 23 deletions hydra/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from dataclasses import dataclass
from abc import ABC, abstractmethod

Expand All @@ -8,14 +9,23 @@
@dataclass(slots=True)
class Cairo1SerializableStruct(ABC):
name: str
elmts: list[ModuloCircuitElement | PyFelt]
elmts: list[ModuloCircuitElement | PyFelt | "Cairo1SerializableStruct"]

def __post_init__(self):
assert type(self.elmts) == list
assert type(self.name) == str
assert all(
isinstance(elmt, (ModuloCircuitElement, PyFelt)) for elmt in self.elmts
)
if isinstance(self.elmts, list):
if isinstance(self.elmts[0], Cairo1SerializableStruct):
assert all(
isinstance(elmt, self.elmts[0].__class__) for elmt in self.elmts
), f"All elements of {self.name} must be of the same type"

else:
assert all(
isinstance(elmt, (ModuloCircuitElement, PyFelt))
for elmt in self.elmts
), f"All elements of {self.name} must be of type ModuloCircuitElement or PyFelt"
else:
assert self.elmts == None

@property
def struct_name(self) -> str:
Expand Down Expand Up @@ -43,6 +53,36 @@ def __len__(self) -> int:
pass


class StructArray(Cairo1SerializableStruct):
@property
def struct_name(self) -> str:
return "Array<" + self.elmts[0].struct_name + ">"

def dump_to_circuit_input(self) -> str:
code = ""
for struct in self.elmts:
code += struct.dump_to_circuit_input()
return code

def __len__(self) -> int:
return sum(len(struct) for struct in self.elmts)

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
raise NotImplementedError

def serialize(self, raw: bool = False) -> str:
raw_struct = f"array!["
for struct in self.elmts:
raw_struct += struct.serialize(raw=True) + ","
raw_struct += "];\n"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.struct_name} = {raw_struct};\n"


class u384(Cairo1SerializableStruct):
def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 1
Expand All @@ -56,7 +96,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 1
return f"let {self.name}:{self.struct_name} = outputs.get_output({offset_to_reference_map[self.elmts[0].offset]});\n"
return f"let {self.name}:{self.struct_name} = outputs.get_output({offset_to_reference_map[self.elmts[0].offset]});"

def dump_to_circuit_input(self) -> str:
return f"circuit_inputs = circuit_inputs.next({self.name});\n"
Expand Down Expand Up @@ -85,7 +125,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 1
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];\n"
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];"

def dump_to_circuit_input(self) -> str:
code = f"""
Expand Down Expand Up @@ -114,7 +154,7 @@ def struct_name(self) -> str:

def serialize(self) -> str:
assert len(self.elmts) == 2
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{yInv: {int_to_u384(self.elmts[0].value)}, xNegOverY: {int_to_u384(self.elmts[1].value)}}};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{yInv: {int_to_u384(self.elmts[0].value)}, xNegOverY: {int_to_u384(self.elmts[1].value)}}};"

def serialize_input_signature(self):
return f"{self.name}:{self.struct_name}"
Expand All @@ -123,7 +163,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 2
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(2)])} }};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(2)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -167,7 +207,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 4
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(4)])} }};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(4)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -203,7 +243,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 2
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(2)])} }};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(2)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -239,7 +279,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 4
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(4)])} }};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(4)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -278,7 +318,7 @@ def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 6
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(6)])} }};\n"
return f"let {self.name}:{self.struct_name} = {self.struct_name} {{ {','.join([f'{self.members_names[i]}: outputs.get_output({offset_to_reference_map[self.elmts[i].offset]})' for i in range(6)])} }};"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -310,12 +350,19 @@ def extract_from_circuit_output(
return code

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 12
raw_struct = f"{self.__class__.__name__}{{{','.join([f'w{i}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
if self.elmts is None:
raw_struct = "Option::None"
if raw:
return raw_struct
else:
return f"let {self.name}:Option<{self.__class__.__name__}> = {raw_struct};\n"
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"
assert len(self.elmts) == 12
raw_struct = f"{self.__class__.__name__}{{{','.join([f'w{i}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down Expand Up @@ -347,12 +394,19 @@ def extract_from_circuit_output(
return code

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 11
raw_struct = f"{self.__class__.__name__}{{{','.join([f'w{i}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
if self.elmts is None:
raw_struct = "Option::None"
if raw:
return raw_struct
else:
return f"let {self.name}:Option<{self.__class__.__name__}> = {raw_struct};\n"
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"
assert len(self.elmts) == 11
raw_struct = f"{self.__class__.__name__}{{{','.join([f'w{i}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.__class__.__name__} = {raw_struct};\n"

def dump_to_circuit_input(self) -> str:
code = ""
Expand Down
13 changes: 10 additions & 3 deletions hydra/poseidon_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def RLC_coeff(self):
self.poseidon_ptr_indexes.append(self.permutations_count - 1)
return self.s1

def hash_element(self, x: PyFelt | ModuloCircuitElement):
def hash_element(self, x: PyFelt | ModuloCircuitElement, debug: bool = False):
# print(f"Will Hash PYTHON {hex(x.value)}")
limbs = bigint_split(x.value, N_LIMBS, BASE)
if debug:
print(f"limbs : {limbs}")
self.s0, self.s1, self.s2 = hades_permutation(
self.s0 + limbs[0] + (BASE) * limbs[1],
self.s1 + limbs[2] + (BASE) * limbs[3],
Expand All @@ -66,12 +68,17 @@ def hash_element(self, x: PyFelt | ModuloCircuitElement):
return self.s0, self.s1

def hash_limbs_multi(
self, X: list[PyFelt | ModuloCircuitElement], sparsity: list[int] = None
self,
X: list[PyFelt | ModuloCircuitElement],
sparsity: list[int] = None,
debug: bool = False,
):
if sparsity:
X = [x for i, x in enumerate(X) if sparsity[i] != 0]
for X_elem in X:
self.hash_element(X_elem)
if debug:
print(f"\t s0 : {self.s0}")
self.hash_element(X_elem, debug=debug)
return None


Expand Down
20 changes: 15 additions & 5 deletions hydra/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def _run_circuit_inner(self, input: list[PyFelt]):
)
circuit.write_p_and_q(input)

m = circuit.multi_pairing_check(n_pairs)
m, _, _, _ = circuit.multi_pairing_check(n_pairs)

circuit.extend_output(m)
circuit.finalize_circuit()
Expand Down Expand Up @@ -1750,15 +1750,23 @@ def _run_circuit_inner(self, input: list[PyFelt]):
comment=f"c_n_minus_1 * ((Π(n-1,k) (Pk(z)) - R_n_minus_1(z))",
)

final_lhs = circuit.add(previous_lhs, lhs_n_minus_1)
final_lhs = circuit.add(
previous_lhs, lhs_n_minus_1, comment="previous_lhs + lhs_n_minus_1"
)
P_irr, P_irr_sparsity = circuit.write_sparse_constant_elements(
get_irreducible_poly(self.curve_id, 12).get_coeffs(),
)
P_of_z = circuit.eval_poly_in_precomputed_Z(P_irr, P_irr_sparsity)
P_of_z = circuit.eval_poly_in_precomputed_Z(
P_irr, P_irr_sparsity, poly_name="P_irr"
)

Q_of_z = circuit.eval_poly_in_precomputed_Z(Q)
Q_of_z = circuit.eval_poly_in_precomputed_Z(Q, poly_name="big_Q")

check = circuit.sub(final_lhs, circuit.mul(Q_of_z, P_of_z))
check = circuit.sub(
final_lhs,
circuit.mul(Q_of_z, P_of_z, comment="Q(z) * P(z)"),
comment="final_lhs - Q(z) * P(z)",
)

circuit.extend_struct_output(u384("final_check", elmts=[check]))
return circuit
Expand Down Expand Up @@ -1916,6 +1924,8 @@ def compilation_mode_to_file_header(mode: int) -> str:
CircuitModulus, AddInputResultTrait, CircuitInputs, CircuitDefinition,
CircuitData, CircuitInputAccumulator
};
use core::circuit::CircuitElement as CE;
use core::circuit::CircuitInput as CI;
use garaga::definitions::{get_a, get_b, get_p, get_g, get_min_one, G1Point, G2Point, E12D, E12DMulQuotient, G1G2Pair, BNProcessedPair, BLSProcessedPair, MillerLoopResultScalingFactor};
use core::option::Option;\n
"""
Expand Down
Loading

0 comments on commit eaf3e8f

Please sign in to comment.