Skip to content

Commit

Permalink
Support u288 struct for pairings & refactor groth16 calldata as felt2…
Browse files Browse the repository at this point in the history
…52 array (#187)

For more details, check #187
  • Loading branch information
feltroidprime authored Sep 9, 2024
1 parent 866281e commit 7e16413
Show file tree
Hide file tree
Showing 44 changed files with 19,188 additions and 20,171 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cairo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- uses: actions/checkout@v3
- uses: software-mansion/setup-scarb@v1
with:
scarb-version: "2.8.0"
scarb-version: "2.8.2"
- run: scarb fmt --check
working-directory: src/
- run: cd src/ && scarb test
2 changes: 1 addition & 1 deletion .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: Setup Scarb
uses: software-mansion/setup-scarb@v1
with:
scarb-version: "2.8.0"
scarb-version: "2.8.2"
- name: Install dependencies
run: make setup

Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ venv
*.sage.py
*.idea
*.secrets
*.pb.gz

tools/garaga_rs/target/
tools/garaga_rs/Cargo.lock
tools/make/requirements.txt

.prev_tests_failed

src/cairo/target/
*target*
Scarb.lock
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ To get started with Garaga, you'll need to have some tools and dependencies inst

Ensure you have the following installed:
- [Python 3.10](https://www.python.org/downloads/) - /!\ Make sure `python3.10` is a valid command in your terminal. The core language used for development. Make sure you have the correct dependencies installed (in particular, GMP) for the `fastecdsa` python package. See [here](https://pypi.org/project/fastecdsa/#installing) for linux and [here](https://github.com/AntonKueltz/fastecdsa/issues/74) for macos.
- [Scarb 2.8.0](https://docs.swmansion.com/scarb/download.html) - The Cairo package manager. Comes with Cairo inside. Requires [Rust](https://www.rust-lang.org/tools/install).
- [Scarb 2.8.2](https://docs.swmansion.com/scarb/download.html) - The Cairo package manager. Comes with Cairo inside. Requires [Rust](https://www.rust-lang.org/tools/install).

##### Optionally :

Expand Down
2 changes: 1 addition & 1 deletion docs/gitbook/installation/developer-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ icon: wrench
To work with Garaga, you need the following dependencies : 

* Python 3.10. The command `python3.10` should be available and working in your terminal. 
* [Scarb](https://docs.swmansion.com/scarb/download.html) v2.8.0. 
* [Scarb](https://docs.swmansion.com/scarb/download.html) v2.8.2. 
* [Rust](https://www.rust-lang.org/tools/install)

Simply clone the [repository](https://github.com/keep-starknet-strange/garaga) :
Expand Down
6 changes: 4 additions & 2 deletions hydra/garaga/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,10 @@ def zero(self) -> PyFelt:
def one(self) -> PyFelt:
return PyFelt(1, self.p)

def random(self) -> PyFelt:
return PyFelt(random.randint(0, self.p - 1), self.p)
def random(self, max_value: int = None) -> PyFelt:
if max_value is None:
max_value = self.p - 1
return PyFelt(random.randint(0, max_value), self.p)

@property
def type(self) -> type[PyFelt]:
Expand Down
31 changes: 31 additions & 0 deletions hydra/garaga/hints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ def to_int(value: str | int | bytes) -> int:
raise TypeError(f"Expected str, int, or bytes, got {type(value).__name__}")


def int_to_u2XX(x: int | PyFelt, curve_id: int = 0, as_hex=True) -> str:
if curve_id == 1:
return int_to_u384(x, as_hex)
else:
return int_to_u288(x, as_hex)


def int_to_u288(x: int | PyFelt, as_hex=True) -> str:
limbs = bigint_split(x, 3, 2**96)
if as_hex:
return f"u288{{limb0:{hex(limbs[0])}, limb1:{hex(limbs[1])}, limb2:{hex(limbs[2])}}}"
else:
return f"u288{{limb0:{limbs[0]}, limb1:{limbs[1]}, limb2:{limbs[2]}}}"


def int_to_u384(x: int | PyFelt, as_hex=True) -> str:
limbs = bigint_split(x, 4, 2**96)
if as_hex:
Expand All @@ -111,6 +126,22 @@ def int_array_to_u384_array(x: list[int] | list[PyFelt], const=False) -> str:
return f"array![{', '.join([int_to_u384(i) for i in x])}]"


def int_array_to_u288_array(x: list[int] | list[PyFelt], const=False) -> str:
if const:
return f"[{', '.join([int_to_u288(i) for i in x])}]"
else:
return f"array![{', '.join([int_to_u288(i) for i in x])}]"


def int_array_to_u2XX_array(
x: list[int] | list[PyFelt], curve_id: int, const=False
) -> str:
if curve_id == 1:
return int_array_to_u384_array(x, const)
else:
return int_array_to_u288_array(x, const)


def bigint_pack(x: object, n_limbs: int, base: int) -> int:
val = 0
for i in range(n_limbs):
Expand Down
141 changes: 130 additions & 11 deletions hydra/garaga/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from garaga.algebra import FunctionFelt, ModuloCircuitElement, PyFelt
from garaga.definitions import STARK, G1Point, G2Point, get_base_field
from garaga.hints import io
from garaga.hints.io import int_array_to_u384_array, int_to_u256, int_to_u384
from garaga.hints.io import (
int_array_to_u384_array,
int_to_u2XX,
int_to_u256,
int_to_u384,
)

T = TypeVar("T", bound="Cairo1SerializableStruct")

Expand Down Expand Up @@ -45,6 +50,10 @@ def __post_init__(self):
def struct_name(self) -> str:
return self.__class__.__name__

@property
def bits(self) -> int:
return self.elmts[0].p.bit_length()

def serialize_input_signature(self) -> str:
return f"{self.name}:{self.struct_name}"

Expand Down Expand Up @@ -217,6 +226,36 @@ def _serialize_to_calldata(self, option: CairoOption = None) -> list[int]:
return cd


# class u288(Cairo1SerializableStruct):
# def serialize(self, raw: bool = False) -> str:
# assert len(self.elmts) == 1
# raw_struct = f"{int_to_u288(self.elmts[0].value)}"
# if raw:
# return raw_struct
# else:
# return f"let {self.name}:{self.struct_name} = {raw_struct};\n"

# def _serialize_to_calldata(self) -> list[int]:
# assert len(self.elmts) == 1
# return io.bigint_split_array(self.elmts, n_limbs=3, prepend_length=False)

# def extract_from_circuit_output(
# self, offset_to_reference_map: dict[int, str]
# ) -> str:
# assert len(self.elmts) == 1
# return f"let {self.name}:{self.struct_name} = outputs.get_output({offset_to_reference_map[self.elmts[0].offset]});"

# def dump_to_circuit_input(self) -> str:
# return f"circuit_inputs = circuit_inputs.next_u288({self.name});\n"

# def __len__(self) -> int:
# if self.elmts is not None:
# assert len(self.elmts) == 1
# return 1
# else:
# return 1


class u384(Cairo1SerializableStruct):
def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 1
Expand Down Expand Up @@ -375,18 +414,41 @@ def __len__(self) -> int:

class u384Array(Cairo1SerializableStruct):
def serialize(self, raw: bool = False) -> str:
raw_struct = f"{int_array_to_u384_array(self.elmts)}"
if len(self.elmts) == 0:
raw_struct = "array![]"
else:
bits = self.bits
if self.name == "g_rhs_sqrt":
# Temp fix before we change the MSMHint and G1Points to depend on the curve id
# Todo : remove this
curve_id = 1
else:
curve_id = 0 if bits <= 288 else 1

raw_struct = f"{io.int_array_to_u2XX_array(self.elmts, curve_id=curve_id)}"
if raw:
return raw_struct
else:
return f"let {self.name}:{self.struct_name} = {raw_struct};\n"

def _serialize_to_calldata(self) -> list[int]:
return io.bigint_split_array(self.elmts, prepend_length=True)
if len(self.elmts) == 0:
return [0]
bits = self.bits
if bits <= 288 and self.name != "g_rhs_sqrt":
return io.bigint_split_array(self.elmts, n_limbs=3, prepend_length=True)
else:
return io.bigint_split_array(self.elmts, n_limbs=4, prepend_length=True)

@property
def struct_name(self) -> str:
return "Array<u384>"
bits = self.bits
if bits <= 288:
return "Array<u288>"
elif bits <= 384:
return "Array<u384>"
else:
raise ValueError(f"Unsupported bit length for u384Array: {bits}")

def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
Expand All @@ -395,10 +457,15 @@ def extract_from_circuit_output(
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];"

def dump_to_circuit_input(self) -> str:
bits = self.bits
if bits <= 288:
next_fn = "next_u288"
else:
next_fn = "next_2"
code = f"""
let mut {self.name} = {self.name};
while let Option::Some(val) = {self.name}.pop_front() {{
circuit_inputs = circuit_inputs.next_2(val);
circuit_inputs = circuit_inputs.{next_fn}(val);
}};
"""
return code
Expand Down Expand Up @@ -695,9 +762,22 @@ def __init__(self, name: str, elmts: list[ModuloCircuitElement]):
super().__init__(name, elmts)
self.members_names = ("r0a0", "r0a1", "r1a0", "r1a1")

def serialize_input_signature(self) -> str:
bits = self.bits
if bits <= 288:
return f"{self.name}:G2Line<u288>"
else:
return f"{self.name}:G2Line<u384>"

def serialize(self, raw: bool = False) -> str:
assert len(self.elmts) == 4
raw_struct = f"{self.struct_name} {{r0a0: {int_to_u384(self.elmts[0].value)}, r0a1: {int_to_u384(self.elmts[1].value)}, r1a0: {int_to_u384(self.elmts[2].value)}, r1a1: {int_to_u384(self.elmts[3].value)}}}"
bits = self.bits
if bits <= 288:
curve_id = 0
else:
curve_id = 1

raw_struct = f"{self.struct_name} {{r0a0: {int_to_u2XX(self.elmts[0].value, curve_id=curve_id)}, r0a1: {int_to_u2XX(self.elmts[1].value, curve_id=curve_id)}, r1a0: {int_to_u2XX(self.elmts[2].value, curve_id=curve_id)}, r1a1: {int_to_u2XX(self.elmts[3].value, curve_id=curve_id)}}}"
if raw:
return raw_struct
else:
Expand All @@ -711,8 +791,15 @@ def extract_from_circuit_output(

def dump_to_circuit_input(self) -> str:
code = ""
bits = self.bits
if bits <= 288:
next_fn = "next_u288"
else:
next_fn = "next_2"
for mem_name in self.members_names:
code += f"circuit_inputs = circuit_inputs.next_2({self.name}.{mem_name});\n"
code += (
f"circuit_inputs = circuit_inputs.{next_fn}({self.name}.{mem_name});\n"
)
return code

def __len__(self) -> int:
Expand Down Expand Up @@ -803,6 +890,14 @@ def extract_from_circuit_output(
code += "};"
return code

@property
def struct_name(self) -> str:
p = self.elmts[0].p
if p.bit_length() <= 288:
return "E12D<u288>"
else:
return "E12D<u384>"

def serialize(self, raw: bool = False, is_option: bool = False) -> str:
if self.elmts is None:
raw_struct = "Option::None"
Expand All @@ -812,7 +907,16 @@ def serialize(self, raw: bool = False, is_option: bool = False) -> str:
return f"let {self.name}:Option<{self.__class__.__name__}> = {raw_struct};\n"
else:
assert len(self.elmts) == 12
raw_struct = f"{self.__class__.__name__}{{{','.join([f'w{i}: {int_to_u384(self.elmts[i].value)}' for i in range(len(self))])}}}"
bits: int = self.elmts[0].p.bit_length()
if bits <= 288:
curve_id = 0
else:
curve_id = 1

raw_struct = (
f"{self.__class__.__name__}{{"
+ f"{','.join([f'w{i}: {int_to_u2XX(self.elmts[i].value, curve_id=curve_id)}' for i in range(len(self))])}}}"
)
if is_option:
raw_struct = f"Option::Some({raw_struct})"
if raw:
Expand All @@ -821,12 +925,27 @@ def serialize(self, raw: bool = False, is_option: bool = False) -> str:
return f"let {self.name} = {raw_struct};\n"

def _serialize_to_calldata(self) -> list[int]:
return io.bigint_split_array(self.elmts, prepend_length=False)
bits: int = self.bits
if bits <= 288:
return io.bigint_split_array(self.elmts, n_limbs=3, prepend_length=False)
elif bits <= 384:
return io.bigint_split_array(self.elmts, n_limbs=4, prepend_length=False)
else:
raise ValueError(f"Unsupported bit length for E12D: {bits}")

def dump_to_circuit_input(self) -> str:
bits: int = self.elmts[0].p.bit_length()
code = ""
for i in range(len(self)):
code += f"circuit_inputs = circuit_inputs.next_2({self.name}.w{i});\n"
if bits <= 288:
for i in range(len(self)):
code += (
f"circuit_inputs = circuit_inputs.next_u288({self.name}.w{i});\n"
)
elif bits <= 384:
for i in range(len(self)):
code += f"circuit_inputs = circuit_inputs.next_2({self.name}.w{i});\n"
else:
raise ValueError(f"Unsupported bit length for E12D: {bits}")
return code

def __len__(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def compilation_mode_to_file_header(mode: int) -> str:
use garaga::core::circuit::AddInputResultTrait2;
use core::circuit::CircuitElement as CE;
use core::circuit::CircuitInput as CI;
use garaga::definitions::{get_a, get_b, get_p, get_g, get_min_one, G1Point, G2Point, E12D, E12DMulQuotient, G1G2Pair, BNProcessedPair, BLSProcessedPair, MillerLoopResultScalingFactor, G2Line};
use garaga::definitions::{get_a, get_b, get_p, get_g, get_min_one, G1Point, G2Point, E12D, u288, E12DMulQuotient, G1G2Pair, BNProcessedPair, BLSProcessedPair, MillerLoopResultScalingFactor, G2Line};
use garaga::ec_ops::{SlopeInterceptOutput, FunctionFeltEvaluations, FunctionFelt};
use core::option::Option;\n
"""
Expand Down
22 changes: 17 additions & 5 deletions hydra/garaga/precompiled_circuits/multi_pairing_check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from garaga.definitions import (
CURVES,
CurveID,
G1G2Pair,
G1Point,
Expand Down Expand Up @@ -408,13 +409,24 @@ def get_pairing_check_input(

assert n_pairs >= 2, "n_pairs must be >= 2 for pairing checks"
field = get_base_field(curve_id.value)
p = G1Point.gen_random_point(curve_id)
q = G2Point.gen_random_point(curve_id)
if n_pairs == 2:
# Generate inputs resembling BLS signature verification
curve = CURVES[curve_id.value]
secret_key = field.random(curve.n).value
public_key = G2Point.get_nG(curve_id, secret_key)
message_hash = G1Point.gen_random_point(curve_id)
signature = message_hash.scalar_mul(secret_key)

P = [signature, message_hash]
Q = [G2Point.get_nG(curve_id, 1), -public_key]
else:
p = G1Point.gen_random_point(curve_id)
q = G2Point.gen_random_point(curve_id)

P = [p] * n_pairs
Q = [q] * n_pairs
P = [p] * n_pairs
Q = [q] * n_pairs

P[-1] = p.scalar_mul(-(n_pairs - 1))
P[-1] = p.scalar_mul(-(n_pairs - 1))
c_input = []
for p, q in zip(P, Q):
c_input.append(field(p.x))
Expand Down
3 changes: 2 additions & 1 deletion hydra/garaga/starknet/groth16_contract_generator/calldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def groth16_calldata_from_vk_and_proof(
)
)

return calldata
# return calldata
return [len(calldata)] + calldata


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 7e16413

Please sign in to comment.