From d310f40dfa89f3d02c211ed96f2d9cca274998e9 Mon Sep 17 00:00:00 2001 From: feltroidprime Date: Thu, 5 Sep 2024 12:36:24 +0200 Subject: [PATCH] handle errors gracefully in binding --- tools/garaga_rs/src/algebra/g1point.rs | 23 +++++++----- tools/garaga_rs/src/definitions.rs | 36 ++++++++++--------- tools/garaga_rs/src/ecip/core.rs | 4 +-- tools/garaga_rs/src/io.rs | 4 +-- tools/garaga_rs/src/msm.rs | 29 ++++++++++----- .../garaga_rs/src/python_bindings/extf_mul.rs | 2 +- tools/garaga_rs/src/wasm_bindings.rs | 32 ++++++++++++----- 7 files changed, 82 insertions(+), 48 deletions(-) diff --git a/tools/garaga_rs/src/algebra/g1point.rs b/tools/garaga_rs/src/algebra/g1point.rs index 2f6eaf2f..e2d861bf 100644 --- a/tools/garaga_rs/src/algebra/g1point.rs +++ b/tools/garaga_rs/src/algebra/g1point.rs @@ -9,19 +9,23 @@ pub struct G1Point { } impl> G1Point { - pub fn new(x: FieldElement, y: FieldElement) -> Self { + pub fn new(x: FieldElement, y: FieldElement) -> Result { let point = Self { x: x.clone(), y: y.clone(), }; if !point.is_infinity() && !point.is_on_curve() { - panic!( + return Err(format!( "Point ({:?}, {:?}) is not on the curve", x.representative().to_string(), y.representative().to_string() - ); + )); } - point + Ok(point) + } + + pub fn new_unchecked(x: FieldElement, y: FieldElement) -> Self { + Self { x, y } } pub fn is_infinity(&self) -> bool { @@ -37,7 +41,7 @@ impl> G1Point { } if self.x == other.x && self.y != other.y { - return G1Point::new(FieldElement::::zero(), FieldElement::::zero()); + return G1Point::new_unchecked(FieldElement::::zero(), FieldElement::::zero()); } let lambda = if self.eq(other) { @@ -52,14 +56,14 @@ impl> G1Point { let x3 = lambda.square() - self.x.clone() - other.x.clone(); let y3 = lambda * (self.x.clone() - x3.clone()) - self.y.clone(); - G1Point::new(x3, y3) + G1Point::new_unchecked(x3, y3) } pub fn neg(&self) -> Self { if self.is_infinity() { self.clone() } else { - G1Point::new(self.x.clone(), -self.y.clone()) + G1Point::new_unchecked(self.x.clone(), -self.y.clone()) } } @@ -73,10 +77,11 @@ impl> G1Point { return self.clone(); } if scalar == BigInt::ZERO { - return G1Point::new(FieldElement::::zero(), FieldElement::::zero()); + return G1Point::new_unchecked(FieldElement::::zero(), FieldElement::::zero()); } - let mut result = G1Point::new(FieldElement::::zero(), FieldElement::::zero()); + let mut result = + G1Point::new_unchecked(FieldElement::::zero(), FieldElement::::zero()); let mut base = self.clone(); //println!("scalar mul scalar: {:?}", scalar); diff --git a/tools/garaga_rs/src/definitions.rs b/tools/garaga_rs/src/definitions.rs index 71850bb4..ec257a56 100644 --- a/tools/garaga_rs/src/definitions.rs +++ b/tools/garaga_rs/src/definitions.rs @@ -21,28 +21,32 @@ pub enum CurveID { X25519 = 4, } -impl From for CurveID { - fn from(value: u8) -> Self { +impl TryFrom for CurveID { + type Error = String; + + fn try_from(value: u8) -> Result { match value { - 0 => CurveID::BN254, - 1 => CurveID::BLS12_381, - 2 => CurveID::SECP256K1, - 3 => CurveID::SECP256R1, - 4 => CurveID::X25519, - _ => panic!("Invalid curve ID"), + 0 => Ok(CurveID::BN254), + 1 => Ok(CurveID::BLS12_381), + 2 => Ok(CurveID::SECP256K1), + 3 => Ok(CurveID::SECP256R1), + 4 => Ok(CurveID::X25519), + _ => Err(format!("Invalid curve ID: {}", value)), } } } -impl From for CurveID { - fn from(value: usize) -> Self { +impl TryFrom for CurveID { + type Error = String; + + fn try_from(value: usize) -> Result { match value { - 0 => CurveID::BN254, - 1 => CurveID::BLS12_381, - 2 => CurveID::SECP256K1, - 3 => CurveID::SECP256R1, - 4 => CurveID::X25519, - _ => panic!("Invalid curve ID"), + 0 => Ok(CurveID::BN254), + 1 => Ok(CurveID::BLS12_381), + 2 => Ok(CurveID::SECP256K1), + 3 => Ok(CurveID::SECP256R1), + 4 => Ok(CurveID::X25519), + _ => Err(format!("Invalid curve ID: {}", value)), } } } diff --git a/tools/garaga_rs/src/ecip/core.rs b/tools/garaga_rs/src/ecip/core.rs index cbdc8082..90830f61 100644 --- a/tools/garaga_rs/src/ecip/core.rs +++ b/tools/garaga_rs/src/ecip/core.rs @@ -42,7 +42,7 @@ where FieldElement: ByteConversion, { let elements = parse_fn(&values); - let points = parse_g1_points_from_flattened_field_elements_list(&elements); + let points = parse_g1_points_from_flattened_field_elements_list(&elements)?; let (q, sum_dlog) = run_ecip(&points, &scalars); Ok(prepare_result(&q, &sum_dlog)) } @@ -295,7 +295,7 @@ fn ecip_functions>( ) -> (G1Point, Vec>) { let mut dss = dss; dss.reverse(); - let mut q = G1Point::new(FieldElement::zero(), FieldElement::zero()); + let mut q = G1Point::new_unchecked(FieldElement::zero(), FieldElement::zero()); let mut divisors: Vec> = Vec::new(); for ds in dss.iter() { let (div, new_q) = row_function(ds.clone(), bs, q); diff --git a/tools/garaga_rs/src/io.rs b/tools/garaga_rs/src/io.rs index 744e1e8e..dacce8e5 100644 --- a/tools/garaga_rs/src/io.rs +++ b/tools/garaga_rs/src/io.rs @@ -6,14 +6,14 @@ use num_bigint::BigUint; pub fn parse_g1_points_from_flattened_field_elements_list( values: &[FieldElement], -) -> Vec> +) -> Result>, String> where F: IsPrimeField + CurveParamsProvider, { values .chunks(2) .map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone())) - .collect() + .collect::, _>>() } pub fn field_elements_from_big_uints(values: &[BigUint]) -> Vec> diff --git a/tools/garaga_rs/src/msm.rs b/tools/garaga_rs/src/msm.rs index 98ad1b7e..b19d3c22 100644 --- a/tools/garaga_rs/src/msm.rs +++ b/tools/garaga_rs/src/msm.rs @@ -26,9 +26,11 @@ pub fn msm_calldata_builder( values: &[BigUint], scalars: &[BigUint], curve_id: usize, -) -> Vec { - assert_eq!(values.len(), 2 * scalars.len()); - let curve_id = CurveID::from(curve_id); +) -> Result, String> { + if values.len() != 2 * scalars.len() { + return Err("Values length must be twice the scalars length".to_string()); + } + let curve_id = CurveID::try_from(curve_id)?; match curve_id { CurveID::BN254 => handle_curve::(values, scalars, curve_id as usize), CurveID::BLS12_381 => { @@ -44,18 +46,24 @@ pub fn msm_calldata_builder( } } -fn handle_curve(values: &[BigUint], scalars: &[BigUint], curve_id: usize) -> Vec +fn handle_curve( + values: &[BigUint], + scalars: &[BigUint], + curve_id: usize, +) -> Result, String> where F: IsPrimeField + CurveParamsProvider, FieldElement: ByteConversion, { let elements = field_elements_from_big_uints::(values); - let points = parse_g1_points_from_flattened_field_elements_list(&elements); + let points = parse_g1_points_from_flattened_field_elements_list(&elements)?; let n = &element_to_biguint(&F::get_curve_params().n); if !scalars.iter().all(|x| x < n) { - panic!("Scalar value must be less than the curve order"); + return Err("Scalar value must be less than the curve order".to_string()); } - calldata_builder(&points, scalars, curve_id, true, true, false) + Ok(calldata_builder( + &points, scalars, curve_id, true, true, false, + )) } pub fn calldata_builder>( @@ -288,7 +296,10 @@ where attempt += 1; } let y = sqrt(&rhs); - (G1Point::new(felt252_to_element(&x_252), y), g_rhs_roots) + ( + G1Point::new_unchecked(felt252_to_element(&x_252), y), + g_rhs_roots, + ) } fn sqrt(value: &FieldElement) -> FieldElement @@ -740,7 +751,7 @@ mod tests { .iter() .map(|s| BigInt::parse_bytes(s.as_bytes(), 10).unwrap()) .collect::>(); - let result = msm_calldata_builder(&values, &scalars, CurveID::BN254 as usize); + let result = msm_calldata_builder(&values, &scalars, CurveID::BN254 as usize).unwrap(); assert_eq!(result, expected); } } diff --git a/tools/garaga_rs/src/python_bindings/extf_mul.rs b/tools/garaga_rs/src/python_bindings/extf_mul.rs index faffb47e..cc2abac3 100644 --- a/tools/garaga_rs/src/python_bindings/extf_mul.rs +++ b/tools/garaga_rs/src/python_bindings/extf_mul.rs @@ -16,7 +16,7 @@ pub fn nondeterministic_extension_field_mul_divmod( .into_iter() .map(|x| x.extract()) .collect::>, _>>()?; - let curve_id = CurveID::from(curve_id); + let curve_id = CurveID::try_from(curve_id).unwrap(); match curve_id { CurveID::BN254 => { handle_extension_field_mul_divmod::(py, ext_degree, list_coeffs) diff --git a/tools/garaga_rs/src/wasm_bindings.rs b/tools/garaga_rs/src/wasm_bindings.rs index 7e8af53d..1baa6823 100644 --- a/tools/garaga_rs/src/wasm_bindings.rs +++ b/tools/garaga_rs/src/wasm_bindings.rs @@ -7,16 +7,30 @@ pub fn msm_calldata_builder( values: Vec, scalars: Vec, curve_id: usize, -) -> Vec { - let values: Vec = values.into_iter().map(jsvalue_to_biguint).collect(); - let scalars: Vec = scalars.into_iter().map(jsvalue_to_biguint).collect(); - let result = crate::msm::msm_calldata_builder(&values, &scalars, curve_id); - result.into_iter().map(bigint_to_jsvalue).collect() +) -> Result, JsValue> { + let values: Vec = values + .into_iter() + .map(jsvalue_to_biguint) + .collect::, _>>()?; + let scalars: Vec = scalars + .into_iter() + .map(jsvalue_to_biguint) + .collect::, _>>()?; + + // Ensure msm_calldata_builder returns a Result type + let result = crate::msm::msm_calldata_builder(&values, &scalars, curve_id) + .map_err(|e| JsValue::from_str(&e.to_string()))?; // Handle error here + + let result: Vec = result; // Ensure result is of type Vec + + Ok(result.into_iter().map(bigint_to_jsvalue).collect()) } -fn jsvalue_to_biguint(v: JsValue) -> BigUint { - let s = (JsValue::from_str("") + v).as_string().unwrap(); - BigUint::from_str(&s).expect("Failed to convert value to non-negative bigint") +fn jsvalue_to_biguint(v: JsValue) -> Result { + let s = (JsValue::from_str("") + v) + .as_string() + .ok_or_else(|| JsValue::from_str("Failed to convert JsValue to string"))?; + BigUint::from_str(&s).map_err(|_| JsValue::from_str("Failed to convert string to BigUint")) } fn bigint_to_jsvalue(v: BigInt) -> JsValue { @@ -35,7 +49,7 @@ mod tests { pub fn test_bigint_marshalling() { let v = 31415usize; assert_eq!( - jsvalue_to_biguint(bigint_to_jsvalue(BigInt::from(v))), + jsvalue_to_biguint(bigint_to_jsvalue(BigInt::from(v))).unwrap(), BigUint::from(v) ); }