Skip to content

Commit

Permalink
update transcript with rust binding
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Mar 14, 2024
1 parent 63e56b0 commit 6dc81ce
Showing 1 changed file with 57 additions and 71 deletions.
128 changes: 57 additions & 71 deletions src/poseidon_transcript.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
from starkware.cairo.common.poseidon_utils import (
PoseidonParams,
hades_permutation,
hades_permutation as hades_permutation_slow,
) ##only for testing times
from src.hints.io import bigint_split
from src.definitions import N_LIMBS, BASE
from src.definitions import N_LIMBS, BASE, STARK
from src.algebra import PyFelt, ModuloCircuitElement


import sys
import os
import hades_binding

# Add the directory containing 'hades_binding' to sys.path
script_dir = os.path.dirname(__file__) # Directory of the current script
parent_dir = os.path.dirname(script_dir) # Parent directory
tools_dir = os.path.join(parent_dir, "tools") # Path to the 'tools' directory
if tools_dir not in sys.path:
sys.path.insert(0, tools_dir)

import hades_binding
def hades_permutation(s0: int, s1: int, s2: int) -> tuple[int, int, int]:
r0, r1, r2 = hades_binding.hades_permutation(
(s0 % STARK).to_bytes(32, "big"),
(s1 % STARK).to_bytes(32, "big"),
(s2 % STARK).to_bytes(32, "big"),
)
return (
int.from_bytes(r0, "big"),
int.from_bytes(r1, "big"),
int.from_bytes(r2, "big"),
)


class CairoPoseidonTranscript:
Expand All @@ -27,10 +30,13 @@ class CairoPoseidonTranscript:
"""

def __init__(self, init_hash: int) -> None:
self.params = PoseidonParams.get_default_poseidon_params()
self.continuable_hash = init_hash
self.s1 = None
self.permutations_count = 0
self.init_hash = init_hash
self.s0, self.s1, self.s2 = hades_permutation(
init_hash,
0,
1,
)
self.permutations_count = 1
self.poseidon_ptr_indexes = []
self.z = None

Expand All @@ -47,82 +53,62 @@ def RLC_coeff(self):
self.poseidon_ptr_indexes.append(self.permutations_count - 1)
return self.s1

def hash_value(self, x: int):
x_temp = x
continuable_hash_temp = self.continuable_hash
s0_bytes, s1_bytes, s2_bytes = hades_binding.hades_permutation(
x_temp.to_bytes(32, byteorder="big"),
continuable_hash_temp.to_bytes(32, byteorder="big"),
(2).to_bytes(32, byteorder="big"),
def hash_element(self, x: PyFelt | ModuloCircuitElement):
# print(f"Will Hash PYTHON {hex(x.value)}")
limbs = bigint_split(x.value, N_LIMBS, BASE)
self.s0, self.s1, self.s2 = hades_permutation(
self.s0 + limbs[0] + (BASE) * limbs[1],
self.s1 + limbs[2] + (BASE) * limbs[3],
self.s2,
)
s0 = int.from_bytes(s0_bytes, "big")
s1 = int.from_bytes(s1_bytes, "big")
self.continuable_hash = s0
self.s1 = s1
self.permutations_count += 1

return self.s0, self.s1

def hash_limbs_multi(
self, X: list[PyFelt | ModuloCircuitElement], sparsity: list[int] = None
) -> tuple[int, int]:
):
if sparsity:
X = [x for i, x in enumerate(X) if sparsity[i] != 0]
for X_elem in X:
# print(f"Will Hash PYTHON {hex(X_elem.value)}")
limbs = bigint_split(X_elem.value, N_LIMBS, BASE)
for i in range(0, N_LIMBS, 2):
combined_limbs = limbs[i] * limbs[i + 1]
self.hash_value(combined_limbs)
return self.continuable_hash, self.s1

def test(self):
return hades_permutation([1, 3, 2], self.params)

# def generate_poseidon_assertions(
# self,
# continuable_hash_name: str,
# num_pairs: int,
# ) -> str:
# cairo_code = ""
# for i in range(num_pairs):
# s0_index = i * 2
# s1_index = s0_index + 1
# if i == 0:
# s1_previous_output = continuable_hash_name
# else:
# s1_previous_output = f"poseidon_ptr[{i-1}].output.s0"
# cairo_code += (
# f" assert poseidon_ptr[{i}].input = PoseidonBuiltinState(\n"
# f" s0=range_check96_ptr[{s0_index}] * range_check96_ptr[{s1_index}], "
# f"s1={s1_previous_output}, s2=two\n"
# " );\n"
# )
# return cairo_code


import time
from timeit import default_timer as timer
self.hash_element(X_elem)
return None


if __name__ == "__main__":
import time
import random

print("Running hades binding test against reference implementation")

print("Running...")
transcript = CairoPoseidonTranscript(init_hash=0)
params = PoseidonParams.get_default_poseidon_params()

random.seed(0)
n_tests = 10000
for i in range(n_tests):
x0, x1, x2 = (
random.randint(0, STARK - 1),
random.randint(0, STARK - 1),
random.randint(0, STARK - 1),
)
ref0, ref1, ref2 = hades_permutation_slow([x0, x1, x2], params)
test0, test1, test2 = hades_permutation(x0, x1, x2)
assert ref0 == test0 and ref1 == test1 and ref2 == test2

print(f"{n_tests} random tests passed!")

print("Running performance test...")
start_time = time.time()
for i in range(0, 10000):
transcript.test()
x = hades_permutation_slow([1, 2, 3], params)
end_time = time.time()
execution_time1 = (end_time - start_time) / 10000

start_time = time.time()
for i in range(0, 10000):
hades_binding.hades_permutation(
(1).to_bytes(32, byteorder="big"),
(3).to_bytes(32, byteorder="big"),
(2).to_bytes(32, byteorder="big"),
)
x = hades_permutation(1, 2, 3)
end_time = time.time()
execution_time2 = (end_time - start_time) / 10000

print(f"hades_permutation execution time Python: {execution_time1} seconds")
print(f"hades_permutation execution time rust: {execution_time2} seconds")
print(f"hades_permutation execution time Python: {execution_time1:2f} seconds")
print(f"hades_permutation execution time rust: {execution_time2:2f} seconds")

0 comments on commit 6dc81ce

Please sign in to comment.