Skip to content

Commit

Permalink
Optimize Coset::at
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Oct 24, 2024
1 parent 72c02f8 commit bb86b10
Show file tree
Hide file tree
Showing 7 changed files with 568 additions and 211 deletions.
16 changes: 0 additions & 16 deletions stwo_cairo_verifier/Scarb.lock
Original file line number Diff line number Diff line change
@@ -1,22 +1,6 @@
# Code generated by scarb DO NOT EDIT.
version = 1

[[package]]
name = "snforge_scarb_plugin"
version = "0.32.0"
source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.32.0#3817c903b640201c72e743b9bbe70a97149828a2"

[[package]]
name = "snforge_std"
version = "0.32.0"
source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.32.0#3817c903b640201c72e743b9bbe70a97149828a2"
dependencies = [
"snforge_scarb_plugin",
]

[[package]]
name = "stwo_cairo_verifier"
version = "0.1.0"
dependencies = [
"snforge_std",
]
160 changes: 62 additions & 98 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ use core::num::traits::one::One;
use core::num::traits::zero::Zero;
use core::num::traits::{WrappingAdd, WrappingSub, WrappingMul};
use stwo_cairo_verifier::channel::{Channel, ChannelImpl};
use stwo_cairo_verifier::circle_mul_table::{
M31_CIRCLE_GEN_MUL_TABLE_BITS_24_TO_29, M31_CIRCLE_GEN_MUL_TABLE_BITS_18_TO_23,
M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17, M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11,
M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5
};
use stwo_cairo_verifier::fields::cm31::CM31;
use stwo_cairo_verifier::fields::m31::{M31, M31Impl};
use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31One, QM31, QM31Trait};
Expand Down Expand Up @@ -42,9 +47,7 @@ pub struct CirclePoint<F> {
pub y: F
}

pub trait CirclePointTrait<
F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>, +Zero<F>, +One<F>, +PartialEq<F>
> {
pub trait CirclePointTrait<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>, +Zero<F>, +One<F>> {
// Returns the neutral element of the circle.
fn zero() -> CirclePoint<F> {
CirclePoint { x: One::one(), y: Zero::zero() }
Expand All @@ -55,42 +58,6 @@ pub trait CirclePointTrait<
let sqx = x * x;
sqx + sqx - One::one()
}

/// Returns the log order of a point.
///
/// All points have an order of the form `2^k`.
fn log_order(
self: @CirclePoint<F>
) -> u32 {
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x.clone();
while cur != One::one() {
cur = Self::double_x(cur);
res += 1;
};
res
}

fn mul(
self: @CirclePoint<F>, scalar: u128
) -> CirclePoint<
F
> {
// TODO: `mut scalar: u128` doesn't work in trait.
let mut scalar = scalar;
let mut result = Self::zero();
let mut cur = *self;
while scalar != 0 {
if scalar & 1 == 1 {
result = result + cur;
}
cur = cur + cur;
scalar /= 2;
};
result
}
}

impl CirclePointAdd<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>> of Add<CirclePoint<F>> {
Expand Down Expand Up @@ -241,9 +208,43 @@ pub impl CirclePointIndexImpl of CirclePointIndexTrait {
self.reduce().index
}

// Context: After running the initial verifier this function accounted for 50% of the steps.
// Most of the calls are made during the FRI folding. With these changes this function accounts
// for ~4% of total steps.
fn to_point(self: @CirclePointIndex) -> CirclePoint<M31> {
const NZ_2_POW_24: NonZero<u32> = 0b1000000000000000000000000;
const NZ_2_POW_18: NonZero<u32> = 0b1000000000000000000;
const NZ_2_POW_12: NonZero<u32> = 0b1000000000000;
const NZ_2_POW_6: NonZero<u32> = 0b1000000;

// No need to call `reduce()`.
M31_CIRCLE_GEN.mul((*self.index).into())
// Start with MSBs since small domains (more common) have LSBs equal 0.
let (bits_24_to_31, bits_0_to_23) = DivRem::div_rem(*self.index, NZ_2_POW_24);
let (bits_30_to_31, bits_24_to_29) = DivRem::div_rem(bits_24_to_31, NZ_2_POW_6);
let mut res = *M31_CIRCLE_GEN_MUL_TABLE_BITS_24_TO_29.span()[bits_24_to_29];
if bits_0_to_23 != 0 {
let (bits_18_to_23, bits_0_to_17) = DivRem::div_rem(bits_0_to_23, NZ_2_POW_18);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_18_TO_23.span()[bits_18_to_23];
if bits_0_to_17 != 0 {
let (bits_12_to_17, bits_0_to_11) = DivRem::div_rem(bits_0_to_17, NZ_2_POW_12);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17.span()[bits_12_to_17];
if bits_0_to_11 != 0 {
let (bits_6_to_11, bits_0_to_5) = DivRem::div_rem(bits_0_to_11, NZ_2_POW_6);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11.span()[bits_6_to_11];
if bits_0_to_5 != 0 {
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5.span()[bits_0_to_5];
}
}
}
}

// Note this applies the appropriate transformation based on the two highest bits.
// The highest bit has no effect. The 30th bit indicates weather to take the antipode.
if bits_30_to_31 == 0b11 || bits_30_to_31 == 0b01 {
res = CirclePoint { x: -res.x, y: -res.y };
}

res
}
}

Expand Down Expand Up @@ -274,13 +275,30 @@ impl CirclePointIndexPartialEx of PartialEq<CirclePointIndex> {
#[cfg(test)]
mod tests {
use stwo_cairo_verifier::fields::m31::m31;
use stwo_cairo_verifier::fields::qm31::{QM31One, qm31};
use stwo_cairo_verifier::utils::pow;
use stwo_cairo_verifier::fields::qm31::QM31One;
use super::{
M31_CIRCLE_GEN, CirclePointQM31Impl, QM31_CIRCLE_GEN, M31_CIRCLE_ORDER, CirclePoint,
CirclePointM31Impl, CirclePointIndexImpl, Coset, CosetImpl, QM31_CIRCLE_ORDER
M31_CIRCLE_GEN, CirclePointQM31Impl, CirclePoint, CirclePointM31Impl, CirclePointIndex,
CirclePointIndexImpl, Coset, CosetImpl
};


#[test]
fn test_to_point() {
let index = CirclePointIndex { index: 0b01111111111111111111111111111111 };
assert_eq!(index.to_point(), -M31_CIRCLE_GEN);
let index = CirclePointIndex { index: 0b00111111111111111111111111111111 };
assert_eq!(index.to_point(), CirclePoint { x: -M31_CIRCLE_GEN.x, y: M31_CIRCLE_GEN.y });
}


#[test]
fn test_to_point_with_unreduced_index() {
// All 32 bits are `1`.
let index = CirclePointIndex { index: 0b11111111111111111111111111111111 };

assert_eq!(index.to_point(), -M31_CIRCLE_GEN);
}

#[test]
fn test_add_1() {
let g4 = CirclePoint { x: m31(0), y: m31(1) };
Expand Down Expand Up @@ -314,52 +332,6 @@ mod tests {
assert_eq!(result, point_1.clone());
}

#[test]
fn test_mul_1() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(5);

assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1);
}

#[test]
fn test_mul_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(8);

assert_eq!(
result, point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1
);
}

#[test]
fn test_mul_3() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(418776494);

assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) });
}

#[test]
fn test_generator_order() {
let half_order = M31_CIRCLE_ORDER / 2;

let mut result = M31_CIRCLE_GEN.mul(half_order.into());

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_generator() {
let mut result = M31_CIRCLE_GEN.mul(pow(2, 30).into());

assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_coset_index_at() {
let coset = Coset {
Expand Down Expand Up @@ -432,13 +404,5 @@ mod tests {

assert_eq!(result, 32);
}

#[test]
fn test_qm31_circle_gen() {
assert_eq!(
QM31_CIRCLE_GEN.mul(QM31_CIRCLE_ORDER / 2),
CirclePoint { x: -qm31(1, 0, 0, 0), y: qm31(0, 0, 0, 0) }
);
}
}

Loading

0 comments on commit bb86b10

Please sign in to comment.