Skip to content

Commit

Permalink
Merge branch 'main' into feat/added_check_to_from_hex_unchecked
Browse files Browse the repository at this point in the history
  • Loading branch information
MauroToscano authored Jul 31, 2023
2 parents 127d48d + 81d8311 commit 23fcebc
Show file tree
Hide file tree
Showing 17 changed files with 4,420 additions and 1,183 deletions.
1 change: 1 addition & 0 deletions crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ lambdaworks-math.workspace = true
sha3 = "0.10"
sha2 = "0.10"
thiserror = "1.0.38"
serde = { version = "1.0", features = ["derive"] }

[dev-dependencies]
criterion = "0.4"
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/merkle_tree/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::traits::IsMerkleTreeBackend;
/// `merkle_path` field, in such a way that, if the merkle tree is of height `n`, the
/// `i`-th element of `merkle_path` is the sibling node in the `n - 1 - i`-th check
/// when verifying.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Proof<T: PartialEq + Eq> {
pub merkle_path: Vec<T>,
}
Expand Down
3 changes: 3 additions & 0 deletions math/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ license.workspace = true

[dependencies]
thiserror = { version = "1.0", optional = true }
serde = { version = "1.0", features = ["derive"], optional = true }
serde_json = { version = "1.0", optional = true }

# rayon
rayon = { version = "1.7", optional = true }
Expand All @@ -33,6 +35,7 @@ iai-callgrind.workspace = true
rayon = ["dep:rayon"]
default = ["rayon", "std"]
std = ["dep:thiserror"]
lambdaworks-serde = ["dep:serde", "dep:serde_json", "std"]

# gpu
metal = [
Expand Down
1 change: 1 addition & 0 deletions math/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub enum ByteConversionError {
/// Error kinds for failed value creation.
// NOTE(review): the variant names suggest these are returned when parsing a
// hexadecimal (`InvalidHexString`) or decimal (`InvalidDecString`) string
// fails — confirm against the constructors that return this error.
#[derive(Debug, PartialEq, Eq)]
pub enum CreationError {
InvalidHexString,
InvalidDecString,
}

#[derive(Debug, PartialEq, Eq)]
Expand Down
29 changes: 24 additions & 5 deletions math/src/fft/gpu/cuda/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ where
{
let mut function = state.get_radix2_dit_butterfly(input, twiddles)?;

const WARP_SIZE: usize = 32;

let block_size = WARP_SIZE;
let butterfly_count = input.len() / 2;
let block_count = (butterfly_count + block_size - 1) / block_size;

let order = input.len().trailing_zeros();
for stage in 0..order {
let group_count = 1 << stage;
let group_size = input.len() / group_count;

function.launch(group_count, group_size)?;
for stage in 0..order {
function.launch(block_count, block_size, stage, butterfly_count as u32)?;
}

let output = function.retrieve_result()?;
Expand Down Expand Up @@ -67,7 +71,7 @@ pub fn bitrev_permutation<F: IsFFTField>(
) -> Result<Vec<FieldElement<F>>, CudaError> {
let mut function = state.get_bitrev_permutation(&input, &input)?;

function.launch(input.len())?;
function.launch()?;

function.retrieve_result()
}
Expand Down Expand Up @@ -116,6 +120,21 @@ mod tests {
}
}

#[test]
fn test_cuda_fft_matches_sequential_large_input() {
    // Input of 2^ORDER elements, all equal to one.
    const ORDER: usize = 20;
    let len = 1_usize << ORDER;
    let elements = vec![FE::one(); len];

    let state = CudaState::new().unwrap();
    let log_len = elements.len().trailing_zeros();
    let twiddles = get_twiddles(log_len.into(), RootsConfig::BitReverse).unwrap();

    // The GPU output must match the CPU reference implementation exactly.
    let gpu_output = fft(&elements, &twiddles, &state).unwrap();
    let cpu_output = crate::fft::cpu::ops::fft(&elements, &twiddles).unwrap();

    assert_eq!(&gpu_output, &cpu_output);
}

#[test]
fn gen_twiddles_with_order_greater_than_63_should_fail() {
let state = CudaState::new().unwrap();
Expand Down
70 changes: 25 additions & 45 deletions math/src/fft/gpu/cuda/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use lambdaworks_gpu::cuda::abstractions::errors::CudaError;
use std::sync::Arc;

const STARK256_PTX: &str = include_str!("../../../gpu/cuda/shaders/field/stark256.ptx");
const WARP_SIZE: usize = 32; // the implementation will spawn threadblocks of this size.

/// Structure for abstracting basic calls to a CUDA device and saving the state. Used for
/// implementing GPU parallel computations in CUDA.
Expand Down Expand Up @@ -166,23 +167,13 @@ impl<F: IsField> Radix2DitButterflyFunction<F> {

pub(crate) fn launch(
&mut self,
group_count: usize,
group_size: usize,
block_count: usize,
block_size: usize,
stage: u32,
butterfly_count: u32,
) -> Result<(), CudaError> {
let grid_dim = (group_count as u32, 1, 1); // in blocks
let block_dim = ((group_size / 2) as u32, 1, 1);

if block_dim.0 as usize > DeviceSlice::len(&self.twiddles) {
return Err(CudaError::IndexOutOfBounds(
block_dim.0 as usize,
self.twiddles.len(),
));
} else if (grid_dim.0 * block_dim.0) as usize > DeviceSlice::len(&self.input) {
return Err(CudaError::IndexOutOfBounds(
(grid_dim.0 * block_dim.0) as usize,
self.input.len(),
));
}
let grid_dim = (block_count as u32, 1, 1); // in blocks
let block_dim = (block_size as u32, 1, 1);

let config = LaunchConfig {
grid_dim,
Expand All @@ -193,9 +184,10 @@ impl<F: IsField> Radix2DitButterflyFunction<F> {
// Calling a kernel is similar to calling a foreign-language function,
// as the kernel itself could be written in C or unsafe Rust.
unsafe {
self.function
.clone()
.launch(config, (&mut self.input, &self.twiddles))
self.function.clone().launch(
config,
(&mut self.input, &self.twiddles, stage, butterfly_count),
)
}
.map_err(|err| CudaError::Launch(err.to_string()))
}
Expand Down Expand Up @@ -235,16 +227,12 @@ impl<F: IsField> CalcTwiddlesFunction<F> {
}
}

pub(crate) fn launch(&mut self, group_size: usize) -> Result<(), CudaError> {
let grid_dim = (1, 1, 1); // in blocks
let block_dim = (group_size as u32, 1, 1);
pub(crate) fn launch(&mut self, count: usize) -> Result<(), CudaError> {
let block_size = WARP_SIZE;
let block_count = (count + block_size - 1) / block_size;

if block_dim.0 as usize > DeviceSlice::len(&self.twiddles) {
return Err(CudaError::IndexOutOfBounds(
block_dim.0 as usize,
self.twiddles.len(),
));
}
let grid_dim = (block_count as u32, 1, 1); // in blocks
let block_dim = (block_size as u32, 1, 1);

let config = LaunchConfig {
grid_dim,
Expand All @@ -257,7 +245,7 @@ impl<F: IsField> CalcTwiddlesFunction<F> {
unsafe {
self.function
.clone()
.launch(config, (&mut self.twiddles, &self.omega))
.launch(config, (&mut self.twiddles, &self.omega, count as u32))
}
.map_err(|err| CudaError::Launch(err.to_string()))
}
Expand Down Expand Up @@ -299,21 +287,13 @@ impl<F: IsField> BitrevPermutationFunction<F> {
}
}

pub(crate) fn launch(&mut self, group_size: usize) -> Result<(), CudaError> {
let grid_dim = (1, 1, 1); // in blocks
let block_dim = (group_size as u32, 1, 1);

if block_dim.0 as usize > DeviceSlice::len(&self.input) {
return Err(CudaError::IndexOutOfBounds(
block_dim.0 as usize,
self.input.len(),
));
} else if block_dim.0 as usize > DeviceSlice::len(&self.result) {
return Err(CudaError::IndexOutOfBounds(
block_dim.0 as usize,
self.result.len(),
));
}
pub(crate) fn launch(&mut self) -> Result<(), CudaError> {
let len = self.input.len();
let block_size = WARP_SIZE;
let block_count = (len + block_size - 1) / block_size;

let grid_dim = (block_count as u32, 1, 1); // in blocks
let block_dim = (block_size as u32, 1, 1);

let config = LaunchConfig {
grid_dim,
Expand All @@ -326,7 +306,7 @@ impl<F: IsField> BitrevPermutationFunction<F> {
unsafe {
self.function
.clone()
.launch(config, (&mut self.input, &self.result))
.launch(config, (&mut self.input, &self.result, len))
}
.map_err(|err| CudaError::Launch(err.to_string()))
}
Expand Down
29 changes: 23 additions & 6 deletions math/src/fft/gpu/metal/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,18 @@ pub fn fft<F: IsFFTField>(
objc::rc::autoreleasepool(|| {
let (command_buffer, command_encoder) = state.setup_command(
&pipeline,
Some(&[(0, &input_buffer), (1, &twiddles_buffer)]),
Some(&[(0, &input_buffer), (1, &twiddles_buffer)]), // index 2 is stage
);

let order = input.len().trailing_zeros();
for stage in 0..order {
let group_count = 1 << stage;
let group_size = input.len() as u64 / group_count;
command_encoder.set_bytes(2, mem::size_of_val(&stage) as u64, void_ptr(&stage));

let threadgroup_size = MTLSize::new(group_size / 2, 1, 1);
let threadgroup_count = MTLSize::new(group_count, 1, 1);
command_encoder.dispatch_thread_groups(threadgroup_count, threadgroup_size);
let grid_size = MTLSize::new(input.len() as u64 / 2, 1, 1); // one thread per butterfly
let threadgroup_size = MTLSize::new(pipeline.thread_execution_width(), 1, 1);

// WARN: Device should support non-uniform threadgroups (Metal3 and Apple4 or later).
command_encoder.dispatch_threads(grid_size, threadgroup_size);
}
command_encoder.end_encoding();

Expand Down Expand Up @@ -181,6 +182,22 @@ mod tests {
}
}

// Takes ~5s to run on an M1; consider lowering `ORDER` if that is too slow.
#[test]
fn test_metal_fft_matches_sequential_large_input() {
    // Input of 2^ORDER elements, all equal to one.
    const ORDER: usize = 20;
    let len = 1_usize << ORDER;
    let elements = vec![FE::one(); len];

    let metal_state = MetalState::new(None).unwrap();
    let log_len = elements.len().trailing_zeros();
    let twiddles = get_twiddles(log_len.into(), RootsConfig::BitReverse).unwrap();

    // The GPU output must match the CPU reference implementation exactly.
    let gpu_output = super::fft(&elements, &twiddles, &metal_state).unwrap();
    let cpu_output = crate::fft::cpu::ops::fft(&elements, &twiddles).unwrap();

    assert_eq!(&gpu_output, &cpu_output);
}

#[test]
fn gen_twiddles_with_order_greater_than_63_should_fail() {
let metal_state = MetalState::new(None).unwrap();
Expand Down
66 changes: 66 additions & 0 deletions math/src/field/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ use crate::unsigned_integer::traits::IsUnsignedInteger;
use core::fmt;
use core::fmt::Debug;
use core::iter::Sum;
#[cfg(feature = "lambdaworks-serde")]
use core::marker::PhantomData;
use core::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
#[cfg(feature = "lambdaworks-serde")]
use serde::de::{self, Deserializer, MapAccess, Visitor};
#[cfg(feature = "lambdaworks-serde")]
use serde::ser::{Serialize, SerializeStruct, Serializer};
#[cfg(feature = "lambdaworks-serde")]
use serde::Deserialize;

use super::fields::montgomery_backed_prime_fields::{IsModulus, MontgomeryBackendPrimeField};
use super::traits::{IsPrimeField, LegendreSymbol};
Expand Down Expand Up @@ -421,6 +429,64 @@ impl<F: IsPrimeField> FieldElement<F> {
}
}

#[cfg(feature = "lambdaworks-serde")]
impl<F: IsPrimeField> Serialize for FieldElement<F> {
    /// Serializes the element as a struct named `FieldElement` with a single
    /// `value` field holding the string form of the element's representative.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut fields = serializer.serialize_struct("FieldElement", 1)?;
        let representative = F::representative(self.value()).to_string();
        fields.serialize_field("value", &representative)?;
        fields.end()
    }
}

#[cfg(feature = "lambdaworks-serde")]
impl<'de, F: IsPrimeField> Deserialize<'de> for FieldElement<F> {
    /// Deserializes a `FieldElement` from a struct/map with a single `value`
    /// field, whose string content is parsed with [`FieldElement::from_hex`].
    ///
    /// # Errors
    ///
    /// Returns a deserialization error when the `value` field is missing, is
    /// given more than once, or holds a string that `from_hex` rejects.
    /// Malformed input never panics.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Identifier for the only field of the serialized struct ("value").
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Value,
        }

        // `PhantomData<fn() -> F>` ties the visitor to the field type `F`
        // without storing a value of that type.
        struct FieldElementVisitor<F>(PhantomData<fn() -> F>);

        impl<'de, F: IsPrimeField> Visitor<'de> for FieldElementVisitor<F> {
            type Value = FieldElement<F>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct FieldElement")
            }

            fn visit_map<M>(self, mut map: M) -> Result<FieldElement<F>, M::Error>
            where
                M: MapAccess<'de>,
            {
                let mut value = None;
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Value => {
                            if value.is_some() {
                                return Err(de::Error::duplicate_field("value"));
                            }
                            value = Some(map.next_value()?);
                        }
                    }
                }
                let value = value.ok_or_else(|| de::Error::missing_field("value"))?;

                // Serialized data may be malformed or untrusted: surface a
                // parse failure as a deserialization error instead of
                // panicking through `unwrap`.
                FieldElement::from_hex(value).map_err(|_| {
                    de::Error::invalid_value(
                        de::Unexpected::Str(value),
                        &"a valid hexadecimal field element",
                    )
                })
            }
        }

        const FIELDS: &[&str] = &["value"];
        deserializer.deserialize_struct("FieldElement", FIELDS, FieldElementVisitor(PhantomData))
    }
}

impl<M, const NUM_LIMBS: usize> fmt::Display
for FieldElement<MontgomeryBackendPrimeField<M, NUM_LIMBS>>
where
Expand Down
14 changes: 14 additions & 0 deletions math/src/field/fields/montgomery_backed_prime_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub trait IsModulus<U>: Debug {
const MODULUS: U;
}

#[cfg_attr(
feature = "lambdaworks-serde",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Clone, Debug, Hash, Copy)]
pub struct MontgomeryBackendPrimeField<M, const NUM_LIMBS: usize> {
phantom: PhantomData<M>,
Expand Down Expand Up @@ -427,6 +431,16 @@ mod tests_u384_prime_fields {
assert_eq!(x * y, c);
}

#[test]
#[cfg(feature = "lambdaworks-serde")]
fn montgomery_backend_serialization_deserialization() {
    // Round-trip an element through JSON and check both the encoded text and
    // the decoded element.
    let element = U384F23Element::from(11_u64);

    let json = serde_json::to_string(&element).unwrap();
    let decoded: U384F23Element = serde_json::from_str(&json).unwrap();

    assert_eq!(json, "{\"value\":\"0xb\"}");
    assert_eq!(decoded, element);
}

const ORDER: usize = 23;
#[test]
fn two_plus_one_is_three() {
Expand Down
9 changes: 5 additions & 4 deletions math/src/gpu/cuda/shaders/fft/bitrev_permutation.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#include "../utils.h"

template <class Fp>
inline __device__ void _bitrev_permutation(const Fp *input, Fp *result)
inline __device__ void _bitrev_permutation(const Fp *input, Fp *result, const int len)
{
unsigned index = threadIdx.x;
unsigned size = blockDim.x;
unsigned thread_pos = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_pos >= len) return;
// TODO: guard is not needed for inputs of len >=block_size * 2, if len is pow of two

result[index] = input[reverse_index(index, size)];
result[thread_pos] = input[reverse_index(thread_pos, len)];
};
Loading

0 comments on commit 23fcebc

Please sign in to comment.