diff --git a/hydra/algebra.py b/hydra/algebra.py index ec4bfdb0..4386902a 100644 --- a/hydra/algebra.py +++ b/hydra/algebra.py @@ -3,6 +3,12 @@ import random from dataclasses import dataclass +from sympy import legendre_symbol, sqrt_mod + +from typing import Generic, TypeVar + +T = TypeVar("T", "PyFelt", "Fp2") + @dataclass(slots=True, frozen=True) class PyFelt: @@ -61,7 +67,11 @@ def __rmul__(self, left: PyFelt | int) -> PyFelt: return self.__mul__(left) def __inv__(self) -> PyFelt: - return PyFelt(pow(self.value, -1, self.p), self.p) + try: + inv = pow(self.value, -1, self.p) + except ValueError: + raise ValueError(f"Cannot invert {self.value} modulo {self.p}") + return PyFelt(inv, self.p) def __truediv__(self, right: PyFelt) -> PyFelt: assert type(self) == type(right), f"Cannot divide {type(self)} by {type(right)}" @@ -129,6 +139,173 @@ def __rsub__(self, left: PyFelt | int) -> PyFelt: def __rtruediv__(self, left: PyFelt | int) -> PyFelt: return self.__inv__().__mul__(left) + def is_quad_residue(self) -> bool: + return legendre_symbol(self.value, self.p) == 1 + + def sqrt(self) -> PyFelt: + if not self.is_quad_residue(): + raise ValueError("Cannot square root a non-quadratic residue") + return PyFelt(min(sqrt_mod(self.value, self.p, all_roots=True)), self.p) + + +@dataclass(slots=True) +class Fp2: + a0: PyFelt + a1: PyFelt + + def __post_init__(self): + assert self.a0.p == self.a1.p, "Fields must be the same" + + @property + def p(self) -> int: + return self.a0.p + + @staticmethod + def random(p: int, max_value: int = None) -> Fp2: + if max_value is None: + max_value = p - 1 + + rnd1 = random.randint(0, max_value) + rnd2 = random.randint(0, max_value) + return Fp2( + PyFelt(rnd1, p), + PyFelt(rnd2, p), + ) + + @staticmethod + def one(p: int) -> Fp2: + return Fp2(PyFelt(1, p), PyFelt(0, p)) + + @staticmethod + def zero(p: int) -> Fp2: + return Fp2(PyFelt(0, p), PyFelt(0, p)) + + def __repr__(self) -> str: + return f"Fp2({self.a0}, {self.a1})" + + def __add__(self, other: Fp2) -> Fp2: + if isinstance(other, Fp2): + return Fp2(self.a0 + other.a0, self.a1 + other.a1) + else: + raise TypeError(f"Cannot add Fp2 and {type(other)}") + + def __eq__(self, other: object) -> bool: + if isinstance(other, Fp2): + return self.a0 == other.a0 and self.a1 == other.a1 and self.p == other.p + else: + raise TypeError(f"Cannot compare Fp2 and {type(other)}") + + def __neg__(self) -> Fp2: + return Fp2(-self.a0, -self.a1) + + def __sub__(self, other: Fp2) -> Fp2: + return self.__add__(-other) + + def __mul__(self, other: Fp2 | PyFelt | int) -> Fp2: + if isinstance(other, PyFelt): + assert other.p == self.a0.p, "Fields must be the same" + return Fp2(self.a0 * other, self.a1 * other) + elif isinstance(other, int): + return Fp2(self.a0 * other, self.a1 * other) + elif isinstance(other, Fp2): + # (a0 + a1 * i) * (b0 + b1 * i) = a0 * b0 - a1 * b1 + (a0 * b1 + a1 * b0) * i + return Fp2( + self.a0 * other.a0 - self.a1 * other.a1, + self.a0 * other.a1 + self.a1 * other.a0, + ) + else: + raise TypeError(f"Cannot multiply Fp2 and {type(other)}") + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + if isinstance(other, Fp2): + return self * other.__inv__() + elif isinstance(other, int): + return self * pow(other, -1, self.p) + + return NotImplemented + + def __rtruediv__(self, other) -> Fp2: + if isinstance(other, Fp2): + return other * self.__inv__() + elif isinstance(other, int): + return other * self.__inv__() + + return NotImplemented + + def __inv__(self) -> Fp2: + t0, t1 = (self.a0 * self.a0, self.a1 * self.a1) + t0 = t0 + t1 + t1 = pow(t0.value, -1, self.p) + return Fp2(self.a0 * t1, -(self.a1 * t1)) + + def __pow__(self, p: int) -> Fp2: + """ + Compute x**p in F_p^2 using square-and-multiply algorithm. + Args: + p: The exponent, a non-negative integer. + Returns: + x**p in F_p^2, represented similarly as x. + """ + assert isinstance(p, int) and p >= 0 + + # Handle the easy cases. + if p == 0: + # x**0 = 1, where 1 is the multiplicative identity in F_p^2. + return Fp2(PyFelt(1, self.p), PyFelt(0, self.p)) + elif p == 1: + # x**1 = x. + return self + + # Start the computation. + result = self.one( + self.p + ) # Initialize result as the multiplicative identity in F_p^2. + temp = self # Initialize temp as self. + + # Loop through each bit of the exponent p. + for bit in reversed(bin(p)[2:]): # [2:] to strip the "0b" prefix. + if bit == "1": + result = result * temp + temp = temp * temp + + return result + + def norm(self) -> PyFelt: + return self.a0 * self.a0 + self.a1 * self.a1 + + def legendre(self) -> int: + norm = self.norm() + return legendre_symbol(norm.value, self.p) + + def is_quad_residue(self) -> bool: + return self.legendre() == 1 + + def sqrt(self) -> Fp2: + if not self.is_quad_residue(): + raise ValueError("Cannot square root a non-quadratic residue") + assert self.p % 4 == 3, "p must be 3 mod 4 to use this sqrt" + min_one = Fp2(PyFelt(-1 % self.p, self.p), PyFelt(0, self.p)) + + a = self + a1 = a ** ((self.p - 3) // 4) + alpha = a1 * a1 * a + a0 = alpha**self.p * alpha + if a0 == min_one: + return ValueError("Cannot square root a non-quadratic residue") + + x0 = a1 * a + if alpha == min_one: + i = Fp2(PyFelt(0, self.p), PyFelt(1, self.p)) + x = i * x0 + else: + b = (Fp2.one(self.p) + alpha) ** ((self.p - 1) // 2) + x = b * x0 + + return x + @dataclass(slots=True) class BaseField: @@ -146,6 +323,35 @@ def one(self) -> PyFelt: def random(self) -> PyFelt: return PyFelt(random.randint(0, self.p - 1), self.p) + @property + def type(self) -> type[PyFelt]: + return PyFelt + + +@dataclass(slots=True) +class BaseFp2Field: + p: int + + def __call__(self, a: tuple[int, int] | int) -> Fp2: + if isinstance(a, tuple): + a0, a1 = a + else: + a0, a1 = a, 0 + return Fp2(PyFelt(a0 % self.p, self.p), PyFelt(a1 % self.p, self.p)) + + def zero(self) -> Fp2: + return Fp2(PyFelt(0, self.p), PyFelt(0, self.p)) + + def one(self) -> Fp2: + return Fp2(PyFelt(1, self.p), PyFelt(0, self.p)) + + def random(self) -> Fp2: + return Fp2.random(self.p) + + @property + def type(self) -> type[Fp2]: + return Fp2 + @dataclass(slots=True, frozen=True) class ModuloCircuitElement: @@ -173,7 +379,7 @@ def felt(self) -> PyFelt: return self.emulated_felt -class Polynomial: +class Polynomial(Generic[T]): """ Represents a polynomial with coefficients in a finite field. @@ -199,15 +405,29 @@ class Polynomial: def __init__( self, - coefficients: list[PyFelt | ModuloCircuitElement], + coefficients: list[T], ): - self.coefficients: list[PyFelt] = [c.felt for c in coefficients] - self.p = coefficients[0].p - self.field = BaseField(self.p) + if all(isinstance(c, (ModuloCircuitElement, PyFelt)) for c in coefficients): + self.coefficients: list[PyFelt] = [c.felt for c in coefficients] + self.type = PyFelt + self.p = coefficients[0].p + self.field = BaseField(self.p) + elif all(isinstance(c, Fp2) for c in coefficients): + self.coefficients: list[Fp2] = coefficients + self.type = Fp2 + self.p = coefficients[0].p + self.field = BaseFp2Field(self.p) + else: + raise TypeError( + f"All elements in the list must be of the same type, either ModuloCircuitElement, PyFelt or Fp2., got {set([type(c) for c in coefficients])}" + ) return def __repr__(self) -> str: - return f"Polynomial({[x.value for x in self.get_coeffs()]})" + if self.type == PyFelt: + return f"Polynomial({[x.value for x in self.get_coeffs()]})" + elif self.type == Fp2: + return f"Polynomial({[x for x in self.get_coeffs()]})" def print_as_sage_poly(self, var_name: str = "z", as_hex: bool = False) -> str: """ @@ -215,13 +435,18 @@ def print_as_sage_poly(self, var_name: str = "z", as_hex: bool = False) -> str: """ if self.is_zero(): return "" - coeffs = self.get_value_coeffs() + coeffs = self.get_coeffs() string = "" + zero = self.field.zero() for i, coeff in enumerate(coeffs[::-1]): - if coeff == 0: + if coeff == zero: continue else: - coeff_str = hex(coeff) if as_hex else str(coeff) + if self.type == PyFelt: + coeff_str = hex(coeff.value) if as_hex else str(coeff.value) + elif self.type == Fp2: + coeff_str = f"({coeff.a1.value} * i + {coeff.a0.value})" + if i == len(coeffs) - 1: string += f"{coeff_str}" @@ -242,20 +467,23 @@ def __len__(self) -> int: def degree(self) -> int: for i in range(len(self.coefficients) - 1, -1, -1): - if self.coefficients[i].value != 0: + if self.coefficients[i] != self.field.zero(): return i return -1 - def get_coeffs(self) -> list[PyFelt]: + def get_coeffs(self) -> list[T]: coeffs = self.coefficients.copy() - while len(coeffs) > 0 and coeffs[-1] == 0: + while len(coeffs) > 0 and coeffs[-1] == self.field.zero(): coeffs.pop() if coeffs == []: return [self.field.zero()] return coeffs def get_value_coeffs(self) -> list[int]: - return [c.value for c in self.get_coeffs()] + if self.type == PyFelt: + return [c.value for c in self.get_coeffs()] + elif self.type == Fp2: + raise NotImplementedError("Fp2 not implemented") def differentiate(self) -> "Polynomial": """ @@ -276,6 +504,10 @@ def differentiate(self) -> "Polynomial": def __add__(self, other: Polynomial) -> Polynomial: if not isinstance(other, Polynomial): raise TypeError(f"Cannot add Polynomial and {type(other)}") + if self.type != other.type: + raise TypeError( + f"Cannot add Polynomial of type {self.type} and {other.type} \n self: {self} \n other: {other}" + ) ns, no = len(self.coefficients), len(other.coefficients) if ns >= no: @@ -300,17 +532,23 @@ def __mul__( ) -> "Polynomial": if isinstance(other, (PyFelt, ModuloCircuitElement)): return Polynomial([c * other.felt for c in self.coefficients]) + elif isinstance(other, Fp2): + return Polynomial([c * other for c in self.coefficients]) elif not isinstance(other, Polynomial): raise TypeError( f"Cannot multiply polynomial by type {type(other)}, must be PyFelt or Polynomial" ) + if self.type != other.type: + raise TypeError( + f"Cannot multiply polynomial of type {self.type} by polynomial of type {other.type}" + ) if self.coefficients == [] or other.coefficients == []: return Polynomial([self.field.zero()]) zero = self.field.zero() buf = [zero] * (len(self.coefficients) + len(other.coefficients) - 1) for i in range(len(self.coefficients)): - if self.coefficients[i] == 0: + if self.coefficients[i] == self.field.zero(): continue # optimization for sparse polynomials for j in range(len(other.coefficients)): buf[i + j] = buf[i + j] + self.coefficients[i] * other.coefficients[j] @@ -341,7 +579,7 @@ def __divmod__(self, denominator: "Polynomial") -> tuple[Polynomial, Polynomial] if denominator.is_zero(): raise ValueError("Cannot divide by zero polynomial") if self.degree() < denominator.degree(): - return (Polynomial([PyFelt(0, self.p)]), self) + return (Polynomial.zero(self.p, self.type), self) field = self.field remainder = Polynomial([n for n in self.coefficients]) quotient_coefficients = [ @@ -384,19 +622,30 @@ def is_zero(self) -> bool: if not self.coefficients: return True for c in self.coefficients: - if c.value != 0: + if c != self.field.zero(): return False return True @staticmethod - def zero(p: int) -> "Polynomial": - return Polynomial([PyFelt(0, p)]) + def zero(p: int, type: type[T] = PyFelt) -> "Polynomial[T]": + if type == PyFelt: + return Polynomial([PyFelt(0, p)]) + elif type == Fp2: + return Polynomial([Fp2.zero(p)]) + else: + raise ValueError(f"Unknown type {type}") @staticmethod - def one(p: int) -> "Polynomial": - return Polynomial([PyFelt(1, p)]) + def one(p: int, type: type[T] = PyFelt) -> "Polynomial[T]": + if type == PyFelt: + return Polynomial([PyFelt(1, p)]) + elif type == Fp2: + return Polynomial([Fp2.one(p)]) + else: + raise ValueError(f"Unknown type {type}") - def evaluate(self, point: PyFelt) -> PyFelt: + def evaluate(self, point: PyFelt | Fp2) -> PyFelt | Fp2: + assert type(point) == self.type, "point type must match polynomial type" xi = self.field.one() value = self.field.zero() for c in self.coefficients: @@ -415,9 +664,14 @@ def __pow__(self, exponent: int) -> "Polynomial": return acc def pow(self, exponent: int, modulo_poly: "Polynomial") -> "Polynomial": + if self.type != modulo_poly.type: + raise TypeError( + f"Cannot pow polynomial of type {self.type} modulo a polynomial of type {modulo_poly.type}" + ) + one = Polynomial.one(self.p, self.type) if exponent == 0: - return Polynomial([PyFelt(1, self.coefficients[0].p)]) - acc = Polynomial([PyFelt(1, self.coefficients[0].p)]) + return one + acc = one for i in reversed(range(len(bin(exponent)[2:]))): acc = acc * acc % modulo_poly if (1 << i) & exponent != 0: @@ -462,8 +716,8 @@ def xgcd(x: Polynomial, y: Polynomial) -> tuple[Polynomial, Polynomial, Polynomi b (Polynomial): A polynomial such that a * x + b * y = g. g (Polynomial): The greatest common divisor of x and y. """ - one = Polynomial([x.field.one()]) - zero = Polynomial([x.field.zero()]) + one = Polynomial.one(x.p, x.type) + zero = Polynomial.zero(x.p, x.type) old_r, r = (x, y) old_s, s = (one, zero) old_t, t = (zero, one) @@ -520,12 +774,12 @@ def lagrange_interpolation( @dataclass(slots=True) -class RationalFunction: - numerator: Polynomial - denominator: Polynomial +class RationalFunction(Generic[T]): + numerator: Polynomial[T] + denominator: Polynomial[T] @property - def field(self) -> BaseField: + def field(self) -> BaseField | BaseFp2Field: return self.numerator.field def simplify(self) -> "RationalFunction": @@ -552,7 +806,7 @@ def __mul__(self, other: int | PyFelt) -> "RationalFunction": raise TypeError(f"Cannot multiply RationalFunction with {type(other)}") return RationalFunction(self.numerator * other, self.denominator) - def evaluate(self, x: PyFelt) -> PyFelt: + def evaluate(self, x: PyFelt | Fp2) -> PyFelt | Fp2: return self.numerator.evaluate(x) / self.denominator.evaluate(x) def degrees_infos(self) -> dict[str, int]: @@ -563,13 +817,13 @@ def degrees_infos(self) -> dict[str, int]: @dataclass(slots=True) -class FunctionFelt: +class FunctionFelt(Generic[T]): # f = a(x) + yb(x) - a: RationalFunction - b: RationalFunction + a: RationalFunction[T] + b: RationalFunction[T] @property - def field(self) -> BaseField: + def field(self) -> BaseField | BaseFp2Field: return self.a.numerator.field def simplify(self) -> "FunctionFelt": @@ -584,7 +838,14 @@ def __mul__(self, other: PyFelt | int) -> "FunctionFelt": def __rmul__(self, other: PyFelt | int) -> "FunctionFelt": return self.__mul__(other) - def evaluate(self, x: PyFelt, y: PyFelt) -> PyFelt: + def evaluate(self, x: PyFelt | Fp2, y: PyFelt | Fp2) -> PyFelt | Fp2: + assert ( + type(x) == self.field.type and x.p == self.field.p + ), f"x type must match field {self.field.type}, got {type(x)} over {hex(x.p)}" + assert ( + type(y) == self.field.type and y.p == self.field.p + ), f"y type must match field {self.field.type}, got {type(y)} over {hex(y.p)}" + return self.a.evaluate(x) + y * self.b.evaluate(x) def degrees_infos(self) -> dict[str, dict[str, int]]: @@ -593,6 +854,14 @@ def degrees_infos(self) -> dict[str, dict[str, int]]: "b": self.b.degrees_infos(), } + def validate_degrees(self, msm_size: int) -> bool: + degrees = self.degrees_infos() + assert degrees["a"]["numerator"] <= msm_size + 1 + assert degrees["a"]["denominator"] <= msm_size + 2 + assert degrees["b"]["numerator"] <= msm_size + 2 + assert degrees["b"]["denominator"] <= msm_size + 5 + return True + def print_as_sage_poly(self, var: str = "x") -> str: return f"(({self.b.numerator.print_as_sage_poly(var)}) / ({self.b.denominator.print_as_sage_poly(var)}) * y + ({self.a.numerator.print_as_sage_poly(var)} / ({self.a.denominator.print_as_sage_poly(var)})" @@ -603,3 +872,26 @@ def print_as_sage_poly(self, var: str = "x") -> str: values = [PyFelt(2, p), PyFelt(4, p)] print(Polynomial.lagrange_interpolation(p, domain, values)) print(PyFelt(1, 12345864586489789789)) + + from hydra.definitions import CURVES, get_base_field, STARK, G2Point, CurveID + + curve_index = 0 + p = CURVES[curve_index].p + a = PyFelt(CURVES[curve_index].a, p) + b_fp2 = Fp2(PyFelt(CURVES[curve_index].b20, p), PyFelt(CURVES[curve_index].b21, p)) + + n = 10000 + quad_residue_count = 0 + for _ in range(n): + x = Fp2.random(p, STARK) + xA = x * a + y2 = x * x * x + xA + b_fp2 + if y2.is_quad_residue(): + quad_residue_count += 1 + y = y2.sqrt() + assert y * y == y2, f"y^2 != y2: {y * y} != {y2}" + pt = G2Point( + (x.a0.value, x.a1.value), (y.a0.value, y.a1.value), CurveID(curve_index) + ) + # print(pt) + print(f"Quadratic residue count: {quad_residue_count} / {n}") diff --git a/hydra/definitions.py b/hydra/definitions.py index ebc1d330..9a91440a 100644 --- a/hydra/definitions.py +++ b/hydra/definitions.py @@ -3,9 +3,22 @@ from dataclasses import dataclass from enum import Enum -from starkware.python.math_utils import EcInfinity, ec_safe_add, ec_safe_mult - -from hydra.algebra import BaseField, ModuloCircuitElement, Polynomial, PyFelt +from starkware.python.math_utils import ( + ec_safe_mult, + EcInfinity, + ec_safe_add, + EC_INFINITY, +) +from starkware.python.math_utils import is_quad_residue, sqrt as sqrt_mod_p + +from hydra.algebra import ( + BaseField, + ModuloCircuitElement, + Polynomial, + PyFelt, + Fp2, + BaseFp2Field, +) from hydra.hints.io import bigint_split, int_to_u256, int_to_u384 N_LIMBS = 4 @@ -271,10 +284,17 @@ def is_generator(g: int, p: int) -> bool: return True -def get_base_field(curve_id: int | CurveID) -> BaseField: +def get_base_field( + curve_id: int | CurveID, type: PyFelt | Fp2 = PyFelt +) -> BaseField | BaseFp2Field: if isinstance(curve_id, CurveID): curve_id = curve_id.value - return BaseField(CURVES[curve_id].p) + if type == PyFelt: + return BaseField(CURVES[curve_id].p) + elif type == Fp2: + return BaseFp2Field(CURVES[curve_id].p) + else: + raise ValueError("Invalid type") def get_irreducible_poly(curve_id: int, extension_degree: int) -> Polynomial: @@ -310,12 +330,43 @@ def __post_init__(self): if not self.is_on_curve(): raise ValueError(f"Point {self} is not on the curve") + @staticmethod + def infinity(curve_id: CurveID) -> "G1Point": + return G1Point(0, 0, curve_id) + def is_infinity(self) -> bool: return self.x == 0 and self.y == 0 def to_cairo_1(self) -> str: return f"G1Point{{x: {int_to_u384(self.x)}, y: {int_to_u384(self.y)}}};" + @staticmethod + def gen_random_point_not_in_subgroup( + curve_id: CurveID, force_gen: bool = False + ) -> "G1Point": + curve_idx = curve_id.value + if CURVES[curve_idx].h == 1: + if force_gen: + return G1Point.gen_random_point(curve_id) + else: + raise ValueError( + "Cofactor is 1, cannot generate a point not in the subgroup" + ) + else: + field = get_base_field(curve_idx) + while True: + x = field.random() + y2 = x**3 + CURVES[curve_idx].a * x + CURVES[curve_idx].b + try: + tentative_point = G1Point(x.value, y2.sqrt().value, curve_id) + except ValueError: + continue + if tentative_point.is_in_prime_order_subgroup() == False: + return tentative_point + + def is_in_prime_order_subgroup(self) -> bool: + return self.scalar_mul(CURVES[self.curve_id.value].n).is_infinity() + def is_on_curve(self) -> bool: """ Check if the point is on the curve using the curve equation y^2 = x^3 + ax + b. @@ -382,7 +433,7 @@ def scalar_mul(self, scalar: int) -> "G1Point": CURVES[self.curve_id.value].a, CURVES[self.curve_id.value].p, ) - if res == EcInfinity: + if isinstance(res, EcInfinity): return G1Point(0, 0, self.curve_id) return G1Point(res[0], res[1], self.curve_id) @@ -418,9 +469,27 @@ class G2Point: curve_id: CurveID def __post_init__(self): + if self.is_infinity(): + return if not self.is_on_curve(): raise ValueError("Point is not on the curve") + @staticmethod + def infinity(curve_id: CurveID) -> "G2Point": + return G2Point((0, 0), (0, 0), curve_id) + + def __eq__(self, other: "G2Point") -> bool: + return ( + self.x[0] == other.x[0] + and self.x[1] == other.x[1] + and self.y[0] == other.y[0] + and self.y[1] == other.y[1] + and self.curve_id == other.curve_id + ) + + def is_infinity(self) -> bool: + return self.x == (0, 0) and self.y == (0, 0) + def is_on_curve(self) -> bool: """ Check if the point is on the curve using the curve equation y^2 = x^3 + ax + b in the extension field. @@ -452,6 +521,72 @@ def gen_random_point(curve_id: CurveID) -> "G2Point": "G2Point.gen_random_point is not implemented for this curve" ) + @staticmethod + def get_nG(curve_id: CurveID, n: int) -> "G1Point": + """ + Returns the scalar multiplication of the generator point on a given curve by the scalar n. + """ + assert ( + n < CURVES[curve_id.value].n + ), f"n must be less than the order of the curve" + + if curve_id.value in GNARK_CLI_SUPPORTED_CURVES: + from tools.gnark_cli import GnarkCLI + + cli = GnarkCLI(curve_id) + ng1ng2 = cli.nG1nG2_operation(1, n, raw=True) + return G2Point((ng1ng2[2], ng1ng2[3]), (ng1ng2[4], ng1ng2[5]), curve_id) + else: + raise NotImplementedError( + "G2Point.get_nG is not implemented for this curve" + ) + + def scalar_mul(self, scalar: int) -> "G2Point": + if self.is_infinity(): + return self + if scalar == 0: + return G2Point((0, 0), (0, 0), self.curve_id) + if scalar < 0: + return -self.scalar_mul(-scalar) + if self.curve_id.value in GNARK_CLI_SUPPORTED_CURVES: + from tools.gnark_cli import GnarkCLI + + cli = GnarkCLI(self.curve_id) + sP = cli.g2_scalarmul((self.x, self.y), scalar) + return G2Point((sP[0], sP[1]), (sP[2], sP[3]), self.curve_id) + else: + raise NotImplementedError( + "G2Point.scalar_mul is not implemented for this curve" + ) + + def add(self, other: "G2Point") -> "G2Point": + if self.is_infinity(): + return other + if other.is_infinity(): + return self + if self.curve_id != other.curve_id: + raise ValueError("Points are not on the same curve") + if self.curve_id.value in GNARK_CLI_SUPPORTED_CURVES: + from tools.gnark_cli import GnarkCLI + + cli = GnarkCLI(self.curve_id) + sP = cli.g2_add((self.x, self.y), (other.x, other.y)) + return G2Point((sP[0], sP[1]), (sP[2], sP[3]), self.curve_id) + + def __neg__(self) -> "G2Point": + p = CURVES[self.curve_id.value].p + return G2Point( + (self.x[0], self.x[1]), (-self.y[0] % p, -self.y[1] % p), self.curve_id + ) + + @staticmethod + def msm(points: list["G2Point"], scalars: list[int]) -> "G2Point": + assert all(type(p) == G2Point for p in points) + assert len(points) == len(scalars) + muls = [P.scalar_mul(s) for P, s in zip(points, scalars)] + scalar_mul = functools.reduce(lambda acc, p: acc.add(p), muls) + return scalar_mul + @dataclass(slots=True) class G1G2Pair: @@ -641,8 +776,7 @@ def replace_consecutive_zeros(lst): if __name__ == "__main__": from random import randint - from tools.extension_trick import (gnark_to_v, gnark_to_w, v_to_gnark, - w_to_gnark) + from tools.extension_trick import gnark_to_v, gnark_to_w, v_to_gnark, w_to_gnark g1 = G1Point.gen_random_point(CurveID.BLS12_381) g2 = G1Point.gen_random_point(CurveID.BN254) @@ -692,4 +826,16 @@ def replace_consecutive_zeros(lst): from tests.benchmarks import test_msm_n_points - test_msm_n_points(CurveID.BLS12_381, 1) + for i in range(5): + test_msm_n_points(CurveID.BLS12_381, 4) + print(f"i = {i} ok") + + for _ in range(10): + p = G1Point.gen_random_point(CurveID.BLS12_381) + assert p.is_in_prime_order_subgroup() + np = G1Point.gen_random_point_not_in_subgroup(CurveID.BLS12_381) + assert np.is_in_prime_order_subgroup() == False + + q = G2Point.gen_random_point(CurveID.BLS12_381) + print(q) + print(q.scalar_mul(1234567890)) diff --git a/hydra/hints/ecip.py b/hydra/hints/ecip.py index c317959a..7202b83e 100644 --- a/hydra/hints/ecip.py +++ b/hydra/hints/ecip.py @@ -7,41 +7,103 @@ from starkware.python.math_utils import is_quad_residue from starkware.python.math_utils import sqrt as sqrt_mod_p -from hydra.algebra import FunctionFelt, Polynomial, PyFelt, RationalFunction -from hydra.definitions import (CURVES, STARK, CurveID, EcInfinity, G1Point, - get_base_field) +from hydra.algebra import FunctionFelt, Polynomial, PyFelt, Fp2, RationalFunction, T +from hydra.definitions import ( + CURVES, + STARK, + CurveID, + EcInfinity, + G1Point, + G2Point, + get_base_field, +) from hydra.hints.io import int_array_to_u384_array, int_to_u384 -from hydra.hints.neg_3 import construct_digit_vectors +from hydra.hints.neg_3 import ( + construct_digit_vectors, + positive_negative_multiplicities, + neg_3_base_le, +) from hydra.poseidon_transcript import hades_permutation +# Check if a Fp2 element is a quadratic residue in Fp2, assuming the irreducible polynomial is x^2 + 1 +# def is_quad_residue_fp2(x: Fp2) -> bool: + + +def get_field_type_from_ec_point(P) -> type[T]: + if isinstance(P, G1Point): + return PyFelt + elif isinstance(P, G2Point): + return Fp2 + else: + raise ValueError(f"Invalid point type {type(P)}") + + +def get_ec_group_class_from_ec_point(P): + if isinstance(P, G1Point): + return G1Point + elif isinstance(P, G2Point): + return G2Point + else: + raise ValueError(f"Invalid point type {type(P)}") + + def derive_ec_point_from_X( - x: PyFelt | int, curve_id: CurveID -) -> tuple[PyFelt, PyFelt, list[PyFelt]]: + x: PyFelt | int | Fp2, curve_id: CurveID +) -> tuple[PyFelt, PyFelt, list[PyFelt]] | tuple[Fp2, Fp2, list[Fp2]]: field = get_base_field(curve_id.value) if isinstance(x, int): x = field(x) - rhs = x**3 + field(CURVES[curve_id.value].a) * x + field(CURVES[curve_id.value].b) - g = field(CURVES[curve_id.value].fp_generator) + def rhs_compute(x: PyFelt | Fp2) -> PyFelt | Fp2: + if isinstance(x, Fp2): + return ( + x**3 + + x * field(CURVES[curve_id.value].a) + + Fp2( + field(CURVES[curve_id.value].b20), field(CURVES[curve_id.value].b21) + ) + ) + else: + return ( + x**3 + + field(CURVES[curve_id.value].a) * x + + field(CURVES[curve_id.value].b) + ) + + if isinstance(x, Fp2): + g = Fp2( + field(CURVES[curve_id.value].nr_a0), + field(CURVES[curve_id.value].nr_a1), + ) + else: + g = field(CURVES[curve_id.value].fp_generator) + + rhs = rhs_compute(x) g_rhs_roots = [] attempt = 0 - while not is_quad_residue(rhs.value, field.p): + while not rhs.is_quad_residue(): g_rhs = rhs * g - g_rhs_roots.append(sqrt_mod_p(g_rhs.value, field.p)) - _x, _, _ = hades_permutation(x.value, attempt, 2) - x = field(_x) - rhs = ( - x**3 + field(CURVES[curve_id.value].a) * x + field(CURVES[curve_id.value].b) - ) + g_rhs_roots.append(g_rhs.sqrt()) + if isinstance(x, Fp2): + _s0, _, _ = hades_permutation(x.a0.value, x.a1.value, 2) + _x0, _x1, _ = hades_permutation(_s0, attempt, 2) + x = Fp2(field(_x0), field(_x1)) + else: + _x, _, _ = hades_permutation(x.value, attempt, 2) + x = field(_x) + + rhs = rhs_compute(x) attempt += 1 - y = field(sqrt_mod_p(rhs.value, field.p)) + y = rhs.sqrt() assert y**2 == rhs - return x, y, [field(r) for r in g_rhs_roots] + return x, y, g_rhs_roots -def zk_ecip_hint(Bs: list[G1Point], scalars: list[int]) -> tuple[G1Point, FunctionFelt]: +def zk_ecip_hint( + Bs: list[G1Point] | list[G2Point], scalars: list[int] +) -> tuple[G1Point | G2Point, FunctionFelt[T]]: """ Inputs: - Bs: list of points on the curve @@ -64,11 +126,66 @@ def zk_ecip_hint(Bs: list[G1Point], scalars: list[int]) -> tuple[G1Point, Functi return Q, sum_dlog -def slope_intercept(P: G1Point, Q: G1Point) -> tuple[PyFelt, PyFelt]: - field = get_base_field(P.curve_id.value) +def verify_ecip( + Bs: list[G1Point] | list[G2Point], scalars: list[int], A0: G1Point | G2Point = None +) -> bool: + Q, sum_dlog = zk_ecip_hint(Bs, scalars) + assert sum_dlog.validate_degrees(len(Bs)) + epns = [ + positive_negative_multiplicities(neg_3_base_le(scalar)) for scalar in scalars + ] + + c_id = Q.curve_id.value + ec_group_class = get_ec_group_class_from_ec_point(Q) + field_type = get_field_type_from_ec_point(Q) + + field = get_base_field(Q.curve_id.value, field_type) + + if A0 is None: + A0 = ec_group_class.gen_random_point(Q.curve_id) + else: + A0 = A0 + + A2 = A0.scalar_mul(-2) + + xA0 = field(A0.x) + yA0 = field(A0.y) + + xA2 = field(A2.x) + yA2 = field(A2.y) + + mA0, bA0 = slope_intercept(A0, A0) + mA0A2, _ = slope_intercept(A2, A0) + + coeff2 = (2 * yA2 * (xA0 - xA2)) / ( + 3 * xA2**2 + field(CURVES[c_id].a) - 2 * mA0A2 * yA2 + ) + coeff0 = coeff2 + 2 * mA0A2 + + basis_sum = field.zero() + for i, (P, (ep, en)) in enumerate(zip(Bs, epns)): + basis_sum += eval_point_challenge_signed(P, xA0, mA0, bA0, ep, en) + # print(f"rhs_acc {i}: {basis_sum.value}") + + if Q.is_infinity(): + RHS = basis_sum + else: + RHS = eval_point_challenge(-Q, xA0, mA0, bA0, 1) + basis_sum + + LHS = coeff0 * sum_dlog.evaluate(xA0, yA0) - coeff2 * sum_dlog.evaluate(xA2, yA2) + + assert LHS == RHS, f"LHS: {LHS}, RHS: {RHS}" + assert Q == ec_group_class.msm(Bs, scalars) + return True + + +def slope_intercept( + P: G1Point | G2Point, Q: G1Point | G2Point +) -> tuple[PyFelt, PyFelt] | tuple[Fp2, Fp2]: + field = get_base_field(P.curve_id.value, get_field_type_from_ec_point(P)) if P == Q: px, py = field(P.x), field(P.y) - m = (3 * px**2 + CURVES[P.curve_id.value].a) / (2 * py) + m = (3 * px**2 + field(CURVES[P.curve_id.value].a)) / (2 * py) b = py - m * px return (m, b) else: @@ -79,11 +196,48 @@ def slope_intercept(P: G1Point, Q: G1Point) -> tuple[PyFelt, PyFelt]: return (m, b) -def line(P: G1Point, Q: G1Point) -> FF: +def eval_point_challenge( + P: G1Point | G2Point, xA0, mA0, bA0, multiplicity: int +) -> PyFelt | Fp2: + field_type = get_field_type_from_ec_point(P) + field = get_base_field(P.curve_id.value, field_type) + + xP, yP = field(P.x), field(P.y) + num = xA0 - xP + den = yP - mA0 * xP - bA0 + res = multiplicity * num / den + assert type(res) == field_type, f"Expected {field_type}, got {type(res)}" + return res + + +def eval_point_challenge_signed( + P: G1Point | G2Point, xA0, mA0, bA0, ep: int, en: int +) -> PyFelt | Fp2: + return eval_point_challenge(P, xA0, mA0, bA0, ep) + eval_point_challenge( + -P, xA0, mA0, bA0, en + ) + + +def line(P: G1Point | G2Point, Q: G1Point | G2Point) -> FF[T]: """ Returns line function passing through points, works for all points and returns 1 for O + O = O """ - field = get_base_field(P.curve_id.value) + assert ( + P.curve_id == Q.curve_id + ), f"Points must be on the same curve, got {P.curve_id} and {Q.curve_id}" + assert type(P) == type( + Q + ), f"Points must be in the same group, got {type(P)} and {type(Q)}" + + if isinstance(P, G1Point): + field_type = PyFelt + elif isinstance(P, G2Point): + field_type = Fp2 + else: + raise ValueError(f"Invalid point type {type(P)}") + + field = get_base_field(P.curve_id.value, field_type) + if P.is_infinity(): if Q.is_infinity(): return FF([Polynomial([field.one()])], P.curve_id) @@ -97,7 +251,7 @@ def line(P: G1Point, Q: G1Point) -> FF: Px, Py = field(P.x), field(P.y) if P == Q: - m = (3 * Px**2 + CURVES[P.curve_id.value].a) / (2 * Py) + m = (3 * Px**2 + field(CURVES[P.curve_id.value].a)) / (2 * Py) b = Py - m * Px # -m*x + y -b return FF([Polynomial([-b, -m]), Polynomial([field.one()])], P.curve_id) @@ -119,18 +273,24 @@ class FF: Example : F(x, y) = c0(x) + c1(x) * y + c2(x) * y^2 + ... """ - coeffs: list[Polynomial] - y2: Polynomial + coeffs: list[Polynomial[T]] + y2: Polynomial[T] p: int curve_id: CurveID + type: type[T] - def __init__(self, coeffs: list[Polynomial], curve_id: CurveID): + def __init__(self, coeffs: list[Polynomial[T]], curve_id: CurveID): self.coeffs = coeffs self.p = coeffs[0][0].p - self.field = get_base_field(curve_id.value) + self.field = get_base_field(curve_id.value, type(coeffs[0][0])) self.curve_id = curve_id + self.type = type(coeffs[0][0]) a = self.field(CURVES[curve_id.value].a) - b = self.field(CURVES[curve_id.value].b) + if self.type == PyFelt: + b = self.field(CURVES[curve_id.value].b) + else: + b = self.field((CURVES[curve_id.value].b20, CURVES[curve_id.value].b21)) + # y² = x³ + ax + b self.y2 = Polynomial([b, a, self.field.zero(), self.field.one()]) @@ -141,7 +301,7 @@ def __getitem__(self, i: int) -> Polynomial: try: return self.coeffs[i] except IndexError: - return Polynomial.zero(self.p) + return Polynomial.zero(self.p, self.type) def __add__(self, other: FF) -> FF: if not isinstance(other, FF): @@ -170,7 +330,7 @@ def __mul__(self, other: "FF" | Polynomial | PyFelt) -> "FF": if self.coeffs == [] or other.coeffs == []: return FF([Polynomial([self.field.zero()])], self.curve_id) - zero = Polynomial.zero(self.p) + zero = Polynomial.zero(self.p, self.type) buf = [zero] * (len(self.coeffs) + len(other.coeffs) - 1) for i in range(len(self.coeffs)): @@ -198,7 +358,7 @@ def reduce(self) -> "FF": """ if len(self.coeffs) <= 2: while len(self.coeffs) < 2: - self.coeffs.append(Polynomial.zero(self.p)) + self.coeffs.append(Polynomial.zero(self.p, self.type)) return self y2 = self.y2 deg_0_coeff = copy.deepcopy(self.coeffs[0]) @@ -228,7 +388,7 @@ def normalize(self) -> "FF": return FF([c * coeff.__inv__() for c in self.coeffs], self.curve_id) -def construct_function(Ps: list[G1Point]) -> FF: +def construct_function(Ps: list[G1Point] | list[G2Point]) -> FF: """ Returns a function exactly interpolating the points Ps """ @@ -260,10 +420,14 @@ def construct_function(Ps: list[G1Point]) -> FF: return D.normalize() -def row_function(ds: list[int], Ps: list[G1Point], Q: G1Point) -> tuple[FF, G1Point]: +def row_function( + ds: list[int], Ps: list[G1Point] | list[G2Point], Q: G1Point | G2Point +) -> tuple[FF, G1Point | G2Point]: + ec_group_class = G1Point if isinstance(Q, G1Point) else G2Point + infinity = ec_group_class.infinity(Q.curve_id) + digits_points = [ - P if d == 1 else -P if d == -1 else G1Point(0, 0, P.curve_id) - for d, P in zip(ds, Ps) + P if d == 1 else -P if d == -1 else infinity for d, P in zip(ds, Ps) ] sum_digits_points = functools.reduce(lambda x, y: x.add(y), digits_points) Q2 = Q.scalar_mul(-3).add(sum_digits_points) @@ -275,10 +439,13 @@ def row_function(ds: list[int], Ps: list[G1Point], Q: G1Point) -> tuple[FF, G1Po return (D, Q2) -def ecip_functions(Bs: list[G1Point], dss: list[list[int]]) -> tuple[G1Point, list[FF]]: +def ecip_functions( + Bs: list[G1Point] | list[G2Point], dss: list[list[int]] +) -> tuple[G1Point | G2Point, list[FF]]: dss.reverse() - Q = G1Point(0, 0, Bs[0].curve_id) + ec_group_class = G1Point if isinstance(Bs[0], G1Point) else G2Point + Q = ec_group_class.infinity(Bs[0].curve_id) Ds = [] for ds in dss: D, Q = row_function(ds, Bs, Q) @@ -305,14 +472,16 @@ def dlog(d: FF) -> FunctionFelt: V=2*y*D: (6*x + 22)*y^2 + (2*x^2 + 2*x + 2)*y """ - field = get_base_field(d.curve_id.value) + field = d.field d: FF = d.reduce() assert len(d.coeffs) == 2, f"D has {len(d.coeffs)} coeffs: {d.coeffs}" Dx = FF([d[0].differentiate(), d[1].differentiate()], d.curve_id) Dy: Polynomial = d[1] # B(x) - TWO_Y: FF = FF([Polynomial.zero(field.p), Polynomial([field(2)])], d.curve_id) + TWO_Y: FF = FF( + [Polynomial.zero(field.p, d.type), Polynomial([field(2)])], d.curve_id + ) U: FF = Dx * TWO_Y + FF( [ Dy @@ -323,7 +492,7 @@ def dlog(d: FF) -> FunctionFelt: field(3), ] # 3x^2 + A ), - Polynomial.zero(field.p), + Polynomial.zero(field.p, d.type), ], d.curve_id, ) @@ -396,28 +565,57 @@ def print_ff(ff: FF): random.seed(0) - def build_cairo1_tests_derive_ec_point_from_X(x: int, curve_id: CurveID, idx: int): - x_f, y, roots = derive_ec_point_from_X(x, curve_id) - - code = f""" - #[test] - fn derive_ec_point_from_X_{CurveID(curve_id).name}_{idx}() {{ - let x: felt252 = {x%STARK}; - let y: u384 = {int_to_u384(y)}; - let grhs_roots:Array = {int_array_to_u384_array(roots)}; - let result = derive_ec_point_from_X(x, y, grhs_roots, {curve_id.value}); - assert!(result.x == {int_to_u384(x_f)}); - assert!(result.y == y); - }} - """ - return code - - codes = "\n".join( - [ - build_cairo1_tests_derive_ec_point_from_X(x, curve_id, idx) - for idx, x in enumerate([random.randint(0, STARK - 1) for _ in range(2)]) - for curve_id in CurveID - ] - ) - - print(codes) + # def build_cairo1_tests_derive_ec_point_from_X(x: int, curve_id: CurveID, idx: int): + # x_f, y, roots = derive_ec_point_from_X(x, curve_id) + + # code = f""" + # #[test] + # fn derive_ec_point_from_X_{CurveID(curve_id).name}_{idx}() {{ + # let x: felt252 = {x%STARK}; + # let y: u384 = {int_to_u384(y)}; + # let grhs_roots:Array = {int_array_to_u384_array(roots)}; + # let result = derive_ec_point_from_X(x, y, grhs_roots, {curve_id.value}); + # assert!(result.x == {int_to_u384(x_f)}); + # assert!(result.y == y); + # }} + # """ + # return code + + # codes = "\n".join( + # [ + # build_cairo1_tests_derive_ec_point_from_X(x, curve_id, idx) + # for idx, x in enumerate([random.randint(0, STARK - 1) for _ in range(2)]) + # for curve_id in CurveID + # ] + # ) + + # print(codes) + + # average_n_roots = 0 + # max_n_roots = 0 + # n = 10000 + # for i in range(n): + # x, y, roots = derive_ec_point_from_X( + # Fp2.random(CURVES[0].p, max_value=STARK), CurveID(0) + # ) + # # print(f"x: {x}, y: {y}, roots: {roots}") + # max_n_roots = max(max_n_roots, len(roots)) + # average_n_roots += len(roots) + # print(f"Average number of roots: {average_n_roots / n}") + # print(f"Max number of roots: {max_n_roots}") + + curve_index = 0 + order = CURVES[curve_index].n + n_points = 4 + # G1 + + Bs = [G1Point.gen_random_point(CurveID(curve_index)) for _ in range(n_points)] + scalars = [random.randint(1, order - 1) for _ in range(n_points)] + + verify_ecip(Bs, scalars) + print("g1 ok") + + # G2 + Bs = [G2Point.gen_random_point(CurveID(curve_index)) for _ in range(n_points)] + verify_ecip(Bs, scalars) + print("g2 ok") diff --git a/tools/gnark/bls12_381/cairo_test/main.go b/tools/gnark/bls12_381/cairo_test/main.go index bba1a9b1..fb1be117 100644 --- a/tools/gnark/bls12_381/cairo_test/main.go +++ b/tools/gnark/bls12_381/cairo_test/main.go @@ -106,6 +106,45 @@ func main() { z.Y.FromMont() fmt.Println(z) + case "g2": + var z, P, Q bls12381.G2Affine + var Px, Py, Qx, Qy fptower.E2 + n := new(big.Int) + n, _ = n.SetString(c.Args().Get(2), 10) + Px.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(3), 10) + Px.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(4), 10) + Py.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(5), 10) + Py.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(6), 10) + Qx.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(7), 10) + Qx.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(8), 10) + Qy.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(9), 10) + Qy.A1.SetBigInt(n) + + P.X = Px + P.Y = Py + Q.X = Qx + Q.Y = Qy + + switch c.Args().Get(1) { + case "add": + + z.Add(&P, &Q) + case "sub": + z.Sub(&P, &Q) + } + + z.X.A0.FromMont() + z.X.A1.FromMont() + z.Y.A0.FromMont() + z.Y.A1.FromMont() + fmt.Println(z) case "ng1": var z, x bls12381.G1Affine var A0, A1 fp.Element @@ -126,6 +165,32 @@ func main() { z.Y.FromMont() fmt.Println(z) + case "ng2": + var P, Q bls12381.G2Affine + var X, Y fptower.E2 + k := new(big.Int) + n := new(big.Int) + n, _ = n.SetString(c.Args().Get(1), 10) + X.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(2), 10) + X.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(3), 10) + Y.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(4), 10) + Y.A1.SetBigInt(n) + k.SetString(c.Args().Get(5), 10) + + P.X = X + P.Y = Y + + Q.ScalarMultiplication(&P, k) + + Q.X.A0.FromMont() + Q.X.A1.FromMont() + Q.Y.A0.FromMont() + Q.Y.A1.FromMont() + fmt.Println(Q) + case "nG1nG2": var P1 bls12381.G1Affine var P2 bls12381.G2Affine diff --git a/tools/gnark/main.go b/tools/gnark/main.go index 3079fd2a..177c697a 100644 --- a/tools/gnark/main.go +++ b/tools/gnark/main.go @@ -103,6 +103,45 @@ func main() { z.X.FromMont() z.Y.FromMont() fmt.Println(z) + case "g2": + var z, P, Q bn254.G2Affine + var Px, Py, Qx, Qy fptower.E2 + n := new(big.Int) + n, _ = n.SetString(c.Args().Get(2), 10) + Px.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(3), 10) + Px.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(4), 10) + Py.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(5), 10) + Py.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(6), 10) + Qx.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(7), 10) + Qx.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(8), 10) + Qy.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(9), 10) + Qy.A1.SetBigInt(n) + + P.X = Px + P.Y = Py + Q.X = Qx + Q.Y = Qy + + switch c.Args().Get(1) { + case "add": + + z.Add(&P, &Q) + case "sub": + z.Sub(&P, &Q) + } + + z.X.A0.FromMont() + z.X.A1.FromMont() + z.Y.A0.FromMont() + z.Y.A1.FromMont() + fmt.Println(z) case "ng1": var z, x bn254.G1Affine @@ -123,6 +162,32 @@ func main() { z.X.FromMont() z.Y.FromMont() fmt.Println(z) + case "ng2": + var Q, P bn254.G2Affine + var X, Y fptower.E2 + n := new(big.Int) + k := new(big.Int) + n, _ = n.SetString(c.Args().Get(1), 10) + X.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(2), 10) + X.A1.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(3), 10) + Y.A0.SetBigInt(n) + n, _ = n.SetString(c.Args().Get(4), 10) + Y.A1.SetBigInt(n) + + k.SetString(c.Args().Get(5), 10) + + P.X = X + P.Y = Y + + Q.ScalarMultiplication(&P, k) + + Q.X.A0.FromMont() + Q.X.A1.FromMont() + Q.Y.A0.FromMont() + Q.Y.A1.FromMont() + fmt.Println(Q) case "nG1nG2": var P1 bn254.G1Affine diff --git a/tools/gnark_cli.py b/tools/gnark_cli.py index 6f2b6387..3d179711 100644 --- a/tools/gnark_cli.py +++ b/tools/gnark_cli.py @@ -78,6 +78,42 @@ def g1_scalarmul(self, p1: tuple[int, int], n: int): assert len(res) == 2, f"Got {output}" return (res[0], res[1]) + def g2_add( + self, + p1: tuple[tuple[int, int], tuple[int, int]], + p2: tuple[tuple[int, int], tuple[int, int]], + ): + args = [ + "g2", + "add", + str(p1[0][0]), + str(p1[0][1]), + str(p1[1][0]), + str(p1[1][1]), + str(p2[0][0]), + str(p2[0][1]), + str(p2[1][0]), + str(p2[1][1]), + ] + output = self.run_command(args) + res = self.parse_fp_elements(output) + assert len(res) == 4, f"Got {output}" + return (res[0], res[1], res[2], res[3]) + + def g2_scalarmul(self, p1: tuple[tuple[int, int], tuple[int, int]], n: int): + args = [ + "ng2", + str(p1[0][0]), + str(p1[0][1]), + str(p1[1][0]), + str(p1[1][1]), + str(n), + ] + output = self.run_command(args) + res = self.parse_fp_elements(output) + assert len(res) == 4, f"Got {output}" + return (res[0], res[1], res[2], res[3]) + def nG1nG2_operation( self, n1: int, n2: int, raw: bool = False ) -> tuple[G1Point, G2Point] | list[int]: