From 6dc81ce5c6623c4034beaf7fe2219462b8f1870b Mon Sep 17 00:00:00 2001 From: feltroid Prime Date: Thu, 14 Mar 2024 19:34:52 +0100 Subject: [PATCH] update transcript with rust binding --- src/poseidon_transcript.py | 128 +++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 71 deletions(-) diff --git a/src/poseidon_transcript.py b/src/poseidon_transcript.py index 030a8686..b603eb92 100644 --- a/src/poseidon_transcript.py +++ b/src/poseidon_transcript.py @@ -1,23 +1,26 @@ from starkware.cairo.common.poseidon_utils import ( PoseidonParams, - hades_permutation, + hades_permutation as hades_permutation_slow, ) ##only for testing times from src.hints.io import bigint_split -from src.definitions import N_LIMBS, BASE +from src.definitions import N_LIMBS, BASE, STARK from src.algebra import PyFelt, ModuloCircuitElement -import sys -import os +import hades_binding -# Add the directory containing 'hades_binding' to sys.path -script_dir = os.path.dirname(__file__) # Directory of the current script -parent_dir = os.path.dirname(script_dir) # Parent directory -tools_dir = os.path.join(parent_dir, "tools") # Path to the 'tools' directory -if tools_dir not in sys.path: - sys.path.insert(0, tools_dir) -import hades_binding +def hades_permutation(s0: int, s1: int, s2: int) -> tuple[int, int, int]: + r0, r1, r2 = hades_binding.hades_permutation( + (s0 % STARK).to_bytes(32, "big"), + (s1 % STARK).to_bytes(32, "big"), + (s2 % STARK).to_bytes(32, "big"), + ) + return ( + int.from_bytes(r0, "big"), + int.from_bytes(r1, "big"), + int.from_bytes(r2, "big"), + ) class CairoPoseidonTranscript: @@ -27,10 +30,13 @@ class CairoPoseidonTranscript: """ def __init__(self, init_hash: int) -> None: - self.params = PoseidonParams.get_default_poseidon_params() - self.continuable_hash = init_hash - self.s1 = None - self.permutations_count = 0 + self.init_hash = init_hash + self.s0, self.s1, self.s2 = hades_permutation( + init_hash, + 0, + 1, + ) + self.permutations_count = 1 self.poseidon_ptr_indexes = [] self.z = None @@ -47,82 +53,62 @@ def RLC_coeff(self): self.poseidon_ptr_indexes.append(self.permutations_count - 1) return self.s1 - def hash_value(self, x: int): - x_temp = x - continuable_hash_temp = self.continuable_hash - s0_bytes, s1_bytes, s2_bytes = hades_binding.hades_permutation( - x_temp.to_bytes(32, byteorder="big"), - continuable_hash_temp.to_bytes(32, byteorder="big"), - (2).to_bytes(32, byteorder="big"), + def hash_element(self, x: PyFelt | ModuloCircuitElement): + # print(f"Will Hash PYTHON {hex(x.value)}") + limbs = bigint_split(x.value, N_LIMBS, BASE) + self.s0, self.s1, self.s2 = hades_permutation( + self.s0 + limbs[0] + (BASE) * limbs[1], + self.s1 + limbs[2] + (BASE) * limbs[3], + self.s2, ) - s0 = int.from_bytes(s0_bytes, "big") - s1 = int.from_bytes(s1_bytes, "big") - self.continuable_hash = s0 - self.s1 = s1 self.permutations_count += 1 return self.s0, self.s1 def hash_limbs_multi( self, X: list[PyFelt | ModuloCircuitElement], sparsity: list[int] = None - ) -> tuple[int, int]: + ): if sparsity: X = [x for i, x in enumerate(X) if sparsity[i] != 0] for X_elem in X: - # print(f"Will Hash PYTHON {hex(X_elem.value)}") - limbs = bigint_split(X_elem.value, N_LIMBS, BASE) - for i in range(0, N_LIMBS, 2): - combined_limbs = limbs[i] * limbs[i + 1] - self.hash_value(combined_limbs) - return self.continuable_hash, self.s1 - - def test(self): - return hades_permutation([1, 3, 2], self.params) - - # def generate_poseidon_assertions( - # self, - # continuable_hash_name: str, - # num_pairs: int, - # ) -> str: - # cairo_code = "" - # for i in range(num_pairs): - # s0_index = i * 2 - # s1_index = s0_index + 1 - # if i == 0: - # s1_previous_output = continuable_hash_name - # else: - # s1_previous_output = f"poseidon_ptr[{i-1}].output.s0" - # cairo_code += ( - # f" assert poseidon_ptr[{i}].input = PoseidonBuiltinState(\n" - # f" s0=range_check96_ptr[{s0_index}] * range_check96_ptr[{s1_index}], " - # f"s1={s1_previous_output}, s2=two\n" - # " );\n" - # ) - # return cairo_code - - -import time -from timeit import default_timer as timer + self.hash_element(X_elem) + return None + if __name__ == "__main__": + import time + import random + + print("Running hades binding test against reference implementation") - print("Running...") - transcript = CairoPoseidonTranscript(init_hash=0) + params = PoseidonParams.get_default_poseidon_params() + + random.seed(0) + n_tests = 10000 + for i in range(n_tests): + x0, x1, x2 = ( + random.randint(0, STARK - 1), + random.randint(0, STARK - 1), + random.randint(0, STARK - 1), + ) + ref0, ref1, ref2 = hades_permutation_slow([x0, x1, x2], params) + test0, test1, test2 = hades_permutation(x0, x1, x2) + assert ref0 == test0 and ref1 == test1 and ref2 == test2 + + print(f"{n_tests} random tests passed!") + + print("Running performance test...") start_time = time.time() for i in range(0, 10000): - transcript.test() + x = hades_permutation_slow([1, 2, 3], params) end_time = time.time() execution_time1 = (end_time - start_time) / 10000 start_time = time.time() for i in range(0, 10000): - hades_binding.hades_permutation( - (1).to_bytes(32, byteorder="big"), - (3).to_bytes(32, byteorder="big"), - (2).to_bytes(32, byteorder="big"), - ) + x = hades_permutation(1, 2, 3) end_time = time.time() execution_time2 = (end_time - start_time) / 10000 - print(f"hades_permutation execution time Python: {execution_time1} seconds") - print(f"hades_permutation execution time rust: {execution_time2} seconds") + print(f"hades_permutation execution time Python: {execution_time1:2f} seconds") + print(f"hades_permutation execution time rust: {execution_time2:2f} seconds")