Skip to content

Commit

Permalink
handle errors gracefully in binding
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Sep 5, 2024
1 parent c548762 commit d310f40
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 48 deletions.
23 changes: 14 additions & 9 deletions tools/garaga_rs/src/algebra/g1point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ pub struct G1Point<F: IsPrimeField> {
}

impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
pub fn new(x: FieldElement<F>, y: FieldElement<F>) -> Self {
pub fn new(x: FieldElement<F>, y: FieldElement<F>) -> Result<Self, String> {
let point = Self {
x: x.clone(),
y: y.clone(),
};
if !point.is_infinity() && !point.is_on_curve() {
panic!(
return Err(format!(
"Point ({:?}, {:?}) is not on the curve",
x.representative().to_string(),
y.representative().to_string()
);
));
}
point
Ok(point)
}

pub fn new_unchecked(x: FieldElement<F>, y: FieldElement<F>) -> Self {
Self { x, y }
}

pub fn is_infinity(&self) -> bool {
Expand All @@ -37,7 +41,7 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
}

if self.x == other.x && self.y != other.y {
return G1Point::new(FieldElement::<F>::zero(), FieldElement::<F>::zero());
return G1Point::new_unchecked(FieldElement::<F>::zero(), FieldElement::<F>::zero());
}

let lambda = if self.eq(other) {
Expand All @@ -52,14 +56,14 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
let x3 = lambda.square() - self.x.clone() - other.x.clone();
let y3 = lambda * (self.x.clone() - x3.clone()) - self.y.clone();

G1Point::new(x3, y3)
G1Point::new_unchecked(x3, y3)
}

pub fn neg(&self) -> Self {
if self.is_infinity() {
self.clone()
} else {
G1Point::new(self.x.clone(), -self.y.clone())
G1Point::new_unchecked(self.x.clone(), -self.y.clone())
}
}

Expand All @@ -73,10 +77,11 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
return self.clone();
}
if scalar == BigInt::ZERO {
return G1Point::new(FieldElement::<F>::zero(), FieldElement::<F>::zero());
return G1Point::new_unchecked(FieldElement::<F>::zero(), FieldElement::<F>::zero());
}

let mut result = G1Point::new(FieldElement::<F>::zero(), FieldElement::<F>::zero());
let mut result =
G1Point::new_unchecked(FieldElement::<F>::zero(), FieldElement::<F>::zero());
let mut base = self.clone();

//println!("scalar mul scalar: {:?}", scalar);
Expand Down
36 changes: 20 additions & 16 deletions tools/garaga_rs/src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,32 @@ pub enum CurveID {
X25519 = 4,
}

impl From<u8> for CurveID {
fn from(value: u8) -> Self {
impl TryFrom<u8> for CurveID {
type Error = String;

fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => CurveID::BN254,
1 => CurveID::BLS12_381,
2 => CurveID::SECP256K1,
3 => CurveID::SECP256R1,
4 => CurveID::X25519,
_ => panic!("Invalid curve ID"),
0 => Ok(CurveID::BN254),
1 => Ok(CurveID::BLS12_381),
2 => Ok(CurveID::SECP256K1),
3 => Ok(CurveID::SECP256R1),
4 => Ok(CurveID::X25519),
_ => Err(format!("Invalid curve ID: {}", value)),
}
}
}

impl From<usize> for CurveID {
fn from(value: usize) -> Self {
impl TryFrom<usize> for CurveID {
type Error = String;

fn try_from(value: usize) -> Result<Self, Self::Error> {
match value {
0 => CurveID::BN254,
1 => CurveID::BLS12_381,
2 => CurveID::SECP256K1,
3 => CurveID::SECP256R1,
4 => CurveID::X25519,
_ => panic!("Invalid curve ID"),
0 => Ok(CurveID::BN254),
1 => Ok(CurveID::BLS12_381),
2 => Ok(CurveID::SECP256K1),
3 => Ok(CurveID::SECP256R1),
4 => Ok(CurveID::X25519),
_ => Err(format!("Invalid curve ID: {}", value)),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tools/garaga_rs/src/ecip/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ where
FieldElement<F>: ByteConversion,
{
let elements = parse_fn(&values);
let points = parse_g1_points_from_flattened_field_elements_list(&elements);
let points = parse_g1_points_from_flattened_field_elements_list(&elements)?;
let (q, sum_dlog) = run_ecip(&points, &scalars);
Ok(prepare_result(&q, &sum_dlog))
}
Expand Down Expand Up @@ -295,7 +295,7 @@ fn ecip_functions<F: IsPrimeField + CurveParamsProvider<F>>(
) -> (G1Point<F>, Vec<FF<F>>) {
let mut dss = dss;
dss.reverse();
let mut q = G1Point::new(FieldElement::zero(), FieldElement::zero());
let mut q = G1Point::new_unchecked(FieldElement::zero(), FieldElement::zero());
let mut divisors: Vec<FF<F>> = Vec::new();
for ds in dss.iter() {
let (div, new_q) = row_function(ds.clone(), bs, q);
Expand Down
4 changes: 2 additions & 2 deletions tools/garaga_rs/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use num_bigint::BigUint;

pub fn parse_g1_points_from_flattened_field_elements_list<F>(
values: &[FieldElement<F>],
) -> Vec<G1Point<F>>
) -> Result<Vec<G1Point<F>>, String>
where
F: IsPrimeField + CurveParamsProvider<F>,
{
values
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect()
.collect::<Result<Vec<_>, _>>()
}

pub fn field_elements_from_big_uints<F>(values: &[BigUint]) -> Vec<FieldElement<F>>
Expand Down
29 changes: 20 additions & 9 deletions tools/garaga_rs/src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ pub fn msm_calldata_builder(
values: &[BigUint],
scalars: &[BigUint],
curve_id: usize,
) -> Vec<BigInt> {
assert_eq!(values.len(), 2 * scalars.len());
let curve_id = CurveID::from(curve_id);
) -> Result<Vec<BigInt>, String> {
if values.len() != 2 * scalars.len() {
return Err("Values length must be twice the scalars length".to_string());
}
let curve_id = CurveID::try_from(curve_id)?;
match curve_id {
CurveID::BN254 => handle_curve::<BN254PrimeField>(values, scalars, curve_id as usize),
CurveID::BLS12_381 => {
Expand All @@ -44,18 +46,24 @@ pub fn msm_calldata_builder(
}
}

fn handle_curve<F>(values: &[BigUint], scalars: &[BigUint], curve_id: usize) -> Vec<BigInt>
fn handle_curve<F>(
values: &[BigUint],
scalars: &[BigUint],
curve_id: usize,
) -> Result<Vec<BigInt>, String>
where
F: IsPrimeField + CurveParamsProvider<F>,
FieldElement<F>: ByteConversion,
{
let elements = field_elements_from_big_uints::<F>(values);
let points = parse_g1_points_from_flattened_field_elements_list(&elements);
let points = parse_g1_points_from_flattened_field_elements_list(&elements)?;
let n = &element_to_biguint(&F::get_curve_params().n);
if !scalars.iter().all(|x| x < n) {
panic!("Scalar value must be less than the curve order");
return Err("Scalar value must be less than the curve order".to_string());
}
calldata_builder(&points, scalars, curve_id, true, true, false)
Ok(calldata_builder(
&points, scalars, curve_id, true, true, false,
))
}

pub fn calldata_builder<F: IsPrimeField + CurveParamsProvider<F>>(
Expand Down Expand Up @@ -288,7 +296,10 @@ where
attempt += 1;
}
let y = sqrt(&rhs);
(G1Point::new(felt252_to_element(&x_252), y), g_rhs_roots)
(
G1Point::new_unchecked(felt252_to_element(&x_252), y),
g_rhs_roots,
)
}

fn sqrt<F>(value: &FieldElement<F>) -> FieldElement<F>
Expand Down Expand Up @@ -740,7 +751,7 @@ mod tests {
.iter()
.map(|s| BigInt::parse_bytes(s.as_bytes(), 10).unwrap())
.collect::<Vec<BigInt>>();
let result = msm_calldata_builder(&values, &scalars, CurveID::BN254 as usize);
let result = msm_calldata_builder(&values, &scalars, CurveID::BN254 as usize).unwrap();
assert_eq!(result, expected);
}
}
2 changes: 1 addition & 1 deletion tools/garaga_rs/src/python_bindings/extf_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn nondeterministic_extension_field_mul_divmod(
.into_iter()
.map(|x| x.extract())
.collect::<Result<Vec<Vec<BigUint>>, _>>()?;
let curve_id = CurveID::from(curve_id);
let curve_id = CurveID::try_from(curve_id).unwrap();
match curve_id {
CurveID::BN254 => {
handle_extension_field_mul_divmod::<BN254PrimeField>(py, ext_degree, list_coeffs)
Expand Down
32 changes: 23 additions & 9 deletions tools/garaga_rs/src/wasm_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,30 @@ pub fn msm_calldata_builder(
values: Vec<JsValue>,
scalars: Vec<JsValue>,
curve_id: usize,
) -> Vec<JsValue> {
let values: Vec<BigUint> = values.into_iter().map(jsvalue_to_biguint).collect();
let scalars: Vec<BigUint> = scalars.into_iter().map(jsvalue_to_biguint).collect();
let result = crate::msm::msm_calldata_builder(&values, &scalars, curve_id);
result.into_iter().map(bigint_to_jsvalue).collect()
) -> Result<Vec<JsValue>, JsValue> {
let values: Vec<BigUint> = values
.into_iter()
.map(jsvalue_to_biguint)
.collect::<Result<Vec<_>, _>>()?;
let scalars: Vec<BigUint> = scalars
.into_iter()
.map(jsvalue_to_biguint)
.collect::<Result<Vec<_>, _>>()?;

// Ensure msm_calldata_builder returns a Result type
let result = crate::msm::msm_calldata_builder(&values, &scalars, curve_id)
.map_err(|e| JsValue::from_str(&e.to_string()))?; // Handle error here

let result: Vec<BigInt> = result; // Ensure result is of type Vec<BigInt>

Ok(result.into_iter().map(bigint_to_jsvalue).collect())
}

fn jsvalue_to_biguint(v: JsValue) -> BigUint {
let s = (JsValue::from_str("") + v).as_string().unwrap();
BigUint::from_str(&s).expect("Failed to convert value to non-negative bigint")
fn jsvalue_to_biguint(v: JsValue) -> Result<BigUint, JsValue> {
let s = (JsValue::from_str("") + v)
.as_string()
.ok_or_else(|| JsValue::from_str("Failed to convert JsValue to string"))?;
BigUint::from_str(&s).map_err(|_| JsValue::from_str("Failed to convert string to BigUint"))
}

fn bigint_to_jsvalue(v: BigInt) -> JsValue {
Expand All @@ -35,7 +49,7 @@ mod tests {
pub fn test_bigint_marshalling() {
let v = 31415usize;
assert_eq!(
jsvalue_to_biguint(bigint_to_jsvalue(BigInt::from(v))),
jsvalue_to_biguint(bigint_to_jsvalue(BigInt::from(v))).unwrap(),
BigUint::from(v)
);
}
Expand Down

0 comments on commit d310f40

Please sign in to comment.