From 6135bd6dec063851a8249e1dafc8a7e4c98abf9d Mon Sep 17 00:00:00 2001 From: fatalbatros Date: Mon, 16 Sep 2024 18:48:49 -0300 Subject: [PATCH] Refactored `from_json` methods of Groth16Proof and Groth16VerifyingKeys classes (#199) Co-authored-by: casiojapi --- .../parsing_utils.py | 115 ++++++++++-------- 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py b/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py index 9e21cb5a..93134f7a 100644 --- a/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py +++ b/hydra/garaga/starknet/groth16_contract_generator/parsing_utils.py @@ -196,11 +196,8 @@ def __post_init__(self): def curve_id(self) -> CurveID: return self.alpha.curve_id - def from_json(file_path: str | Path) -> "Groth16VerifyingKey": - path = Path(file_path) + def from_dict(data: dict) -> "Groth16VerifyingKey": try: - with path.open("r") as f: - data = json.load(f) curve_id = try_guessing_curve_id_from_json(data) try: verifying_key = find_item_from_key_patterns(data, ["verifying_key"]) @@ -237,6 +234,14 @@ def from_json(file_path: str | Path) -> "Groth16VerifyingKey": for point in find_item_from_key_patterns(g1_points, ["K"]) ], ) + except KeyError as e: + raise KeyError(f"The key {e} is missing from the JSON data.") + + def from_json(file_path: str | Path) -> "Groth16VerifyingKey": + path = Path(file_path) + try: + with path.open("r") as f: + data = json.load(f) except FileNotFoundError: cwd = os.getcwd() print(f"Current working directory: {cwd}") @@ -244,8 +249,7 @@ def from_json(file_path: str | Path) -> "Groth16VerifyingKey": raise FileNotFoundError(f"The file {file_path} was not found.") except json.JSONDecodeError: raise ValueError(f"The file {file_path} does not contain valid JSON.") - except KeyError as e: - raise KeyError(f"The key {e} is missing from the JSON data.") + return Groth16VerifyingKey.from_dict(data) def serialize_to_cairo(self) -> str: # Precompute M = miller_loop(public_pair) @@ -308,6 +312,49 @@ def __post_init__(self): ), f"All points must be on the same curve, got {self.a.curve_id}, {self.b.curve_id}, {self.c.curve_id}" self.curve_id = self.a.curve_id + def from_dict( + data: dict, public_inputs: None | list | dict = None + ) -> "Groth16Proof": + curve_id = try_guessing_curve_id_from_json(data) + try: + proof = find_item_from_key_patterns(data, ["proof"]) + except ValueError: + proof = data + + try: + seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"])) + image_id = io.to_hex_str(find_item_from_key_patterns(data, ["image_id"])) + journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"])) + + return Groth16Proof._from_risc0( + seal=bytes.fromhex(seal[2:]), + image_id=bytes.fromhex(image_id[2:]), + journal=bytes.fromhex(journal[2:]), + ) + except ValueError: + pass + except KeyError: + pass + except Exception as e: + print(f"Error: {e}") + raise e + + if public_inputs is not None: + if isinstance(public_inputs, dict): + public_inputs = list(public_inputs.values()) + elif isinstance(public_inputs, list): + pass + else: + raise ValueError(f"Invalid public inputs format: {public_inputs}") + else: + public_inputs = find_item_from_key_patterns(data, ["public"]) + return Groth16Proof( + a=try_parse_g1_point_from_key(proof, ["a"], curve_id), + b=try_parse_g2_point_from_key(proof, ["b"], curve_id), + c=try_parse_g1_point_from_key(proof, ["c", "Krs"], curve_id), + public_inputs=[io.to_int(pub) for pub in public_inputs], + ) + def from_json( proof_path: str | Path, public_inputs_path: str | Path = None ) -> "Groth16Proof": @@ -315,57 +362,23 @@ def from_json( try: with path.open("r") as f: data = json.load(f) - # print(f"data: {data}") - # print(f"data.keys(): {data.keys()}") - curve_id = try_guessing_curve_id_from_json(data) - - try: - proof = find_item_from_key_patterns(data, ["proof"]) - except ValueError: - proof = data - - try: - seal = io.to_hex_str(find_item_from_key_patterns(data, ["seal"])) - image_id = io.to_hex_str( - find_item_from_key_patterns(data, ["image_id"]) - ) - journal = io.to_hex_str(find_item_from_key_patterns(data, ["journal"])) - - return Groth16Proof._from_risc0( - seal=bytes.fromhex(seal[2:]), - image_id=bytes.fromhex(image_id[2:]), - journal=bytes.fromhex(journal[2:]), - ) - except ValueError: - pass - except KeyError: - pass - except Exception as e: - print(f"Error: {e}") - raise e - + except FileNotFoundError: + raise FileNotFoundError(f"The file {proof_path} was not found.") + except json.JSONDecodeError: + raise ValueError(f"The file {proof_path} does not contain valid JSON.") + try: if public_inputs_path is not None: with Path(public_inputs_path).open("r") as f: public_inputs = json.load(f) - print(f"public_inputs: {public_inputs}") - if isinstance(public_inputs, dict): - public_inputs = list(public_inputs.values()) - elif isinstance(public_inputs, list): - pass - else: - raise ValueError(f"Invalid public inputs format: {public_inputs}") else: - public_inputs = find_item_from_key_patterns(data, ["public"]) - return Groth16Proof( - a=try_parse_g1_point_from_key(proof, ["a"], curve_id), - b=try_parse_g2_point_from_key(proof, ["b"], curve_id), - c=try_parse_g1_point_from_key(proof, ["c", "Krs"], curve_id), - public_inputs=[io.to_int(pub) for pub in public_inputs], - ) + public_inputs = None except FileNotFoundError: - raise FileNotFoundError(f"The file {proof_path} was not found.") + raise FileNotFoundError(f"The file {public_inputs_path} was not found.") except json.JSONDecodeError: - raise ValueError(f"The file {proof_path} does not contain valid JSON.") + raise ValueError( + f"The file {public_inputs_path} does not contain valid JSON." + ) + return Groth16Proof.from_dict(data, public_inputs) def _from_risc0( seal: bytes,