Skip to content

Commit

Permalink
refactor: transpose_column functions are moved to a seperate class
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtrombetta committed Jun 12, 2024
1 parent e85945d commit 34aea2b
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 333 deletions.
Original file line number Diff line number Diff line change
@@ -1,148 +1,12 @@
use super::{pairings, DoryCommitment, DoryProverPublicSetup, DoryScalar, G1Affine};
use super::{pairings, transpose, DoryCommitment, DoryProverPublicSetup, DoryScalar, G1Affine};
use crate::base::commitment::CommittableColumn;
use ark_bls12_381::Fr;
use ark_ec::CurveGroup;
use ark_std::ops::Mul;
use blitzar::{compute::ElementP2, sequence::Sequence};
use num_traits::ToBytes;
use rayon::prelude::*;
use zerocopy::AsBytes;

trait OffsetToBytes {
const IS_SIGNED: bool;
fn min_as_fr() -> Fr;
fn offset_to_bytes(&self) -> Vec<u8>;
}

impl OffsetToBytes for u8 {
const IS_SIGNED: bool = false;

fn min_as_fr() -> Fr {
Fr::from(0)
}

fn offset_to_bytes(&self) -> Vec<u8> {
vec![*self]
}
}

impl OffsetToBytes for i16 {
const IS_SIGNED: bool = true;

fn min_as_fr() -> Fr {
Fr::from(i16::MIN)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let shifted = self.wrapping_sub(i16::MIN);
shifted.to_le_bytes().to_vec()
}
}

impl OffsetToBytes for i32 {
const IS_SIGNED: bool = true;

fn min_as_fr() -> Fr {
Fr::from(i32::MIN)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let shifted = self.wrapping_sub(i32::MIN);
shifted.to_le_bytes().to_vec()
}
}

impl OffsetToBytes for i64 {
const IS_SIGNED: bool = true;

fn min_as_fr() -> Fr {
Fr::from(i64::MIN)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let shifted = self.wrapping_sub(i64::MIN);
shifted.to_le_bytes().to_vec()
}
}

impl OffsetToBytes for i128 {
const IS_SIGNED: bool = true;

fn min_as_fr() -> Fr {
Fr::from(i128::MIN)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let shifted = self.wrapping_sub(i128::MIN);
shifted.to_le_bytes().to_vec()
}
}

impl OffsetToBytes for bool {
const IS_SIGNED: bool = false;

fn min_as_fr() -> Fr {
Fr::from(false)
}

fn offset_to_bytes(&self) -> Vec<u8> {
vec![*self as u8]
}
}

impl OffsetToBytes for u64 {
const IS_SIGNED: bool = false;

fn min_as_fr() -> Fr {
Fr::from(0)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let bytes = self.to_le_bytes();
bytes.to_vec()
}
}

impl OffsetToBytes for [u64; 4] {
const IS_SIGNED: bool = false;

fn min_as_fr() -> Fr {
Fr::from(0)
}

fn offset_to_bytes(&self) -> Vec<u8> {
let slice = self.as_bytes();
slice.to_vec()
}
}

#[tracing::instrument(name = "transpose_column (gpu)", level = "debug", skip_all)]
fn transpose_column<T: AsBytes + Copy + OffsetToBytes>(
column: &[T],
offset: usize,
num_columns: usize,
data_size: usize,
) -> Vec<u8> {
let column_len_with_offset = column.len() + offset;
let total_length_bytes =
data_size * (((column_len_with_offset + num_columns - 1) / num_columns) * num_columns);
let cols = num_columns;
let rows = total_length_bytes / (data_size * cols);

let mut transpose = vec![0_u8; total_length_bytes];
for n in offset..(column.len() + offset) {
let i = n / cols;
let j = n % cols;
let t_idx = (j * rows + i) * data_size;
let p_idx = (i * cols + j) - offset;

transpose[t_idx..t_idx + data_size]
.copy_from_slice(column[p_idx].offset_to_bytes().as_slice());
}

transpose
}

#[tracing::instrument(name = "get_offset_commits (gpu)", level = "debug", skip_all)]
fn get_offset_commits(
column_len: usize,
Expand All @@ -167,7 +31,7 @@ fn get_offset_commits(
// Get the commit of the first non-zero row
let first_row_offset = offset - (num_zero_commits * num_columns);
let first_row_transpose =
transpose_column(first_row, first_row_offset, num_columns, data_size);
transpose::transpose_for_fixed_msm(first_row, first_row_offset, num_columns, data_size);

setup.public_parameters().blitzar_handle.msm(
&mut ones_blitzar_commits[num_zero_commits..num_zero_commits + 1],
Expand All @@ -179,7 +43,8 @@ fn get_offset_commits(
let mut chunks = remaining_elements.chunks(num_columns);
if chunks.len() > 1 {
if let Some(middle_row) = chunks.next() {
let middle_row_transpose = transpose_column(middle_row, 0, num_columns, data_size);
let middle_row_transpose =
transpose::transpose_for_fixed_msm(middle_row, 0, num_columns, data_size);
let mut middle_row_blitzar_commit =
vec![ElementP2::<ark_bls12_381::g1::Config>::default(); 1];

Expand All @@ -197,7 +62,8 @@ fn get_offset_commits(

// Get the commit of the last row to handle an zero padding at the end of the column
if let Some(last_row) = remaining_elements.chunks(num_columns).last() {
let last_row_transpose = transpose_column(last_row, 0, num_columns, data_size);
let last_row_transpose =
transpose::transpose_for_fixed_msm(last_row, 0, num_columns, data_size);

setup.public_parameters().blitzar_handle.msm(
&mut ones_blitzar_commits[num_of_commits - 1..num_of_commits],
Expand All @@ -223,13 +89,14 @@ fn compute_dory_commitment_impl<'a, T>(
where
&'a T: Into<DoryScalar>,
&'a [T]: Into<Sequence<'a>>,
T: AsBytes + Copy + OffsetToBytes,
T: AsBytes + Copy + transpose::OffsetToBytes,
{
let num_columns = 1 << setup.sigma();
let data_size = std::mem::size_of::<T>();

// Format column to match column major data layout required by blitzar's msm
let column_transpose = transpose_column(column, offset, num_columns, data_size);
let column_transpose =
transpose::transpose_for_fixed_msm(column, offset, num_columns, data_size);
let num_of_commits = column_transpose.len() / (data_size * num_columns);
let gamma_2_slice = &setup.public_parameters().Gamma_2[0..num_of_commits];

Expand Down Expand Up @@ -293,194 +160,3 @@ pub(super) fn compute_dory_commitments(
.map(|column| compute_dory_commitment(column, offset, setup))
.collect()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn we_can_transpose_empty_column() {
type T = u64;
let column: Vec<T> = vec![];
let offset = 0;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);
assert!(transpose.is_empty());
}

#[test]
fn we_can_transpose_u64_column() {
type T = u64;
let column: Vec<T> = vec![0, 1, 2, 3];
let offset = 0;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(&transpose[0..data_size], column[0].as_bytes());
assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes());
assert_eq!(
&transpose[2 * data_size..3 * data_size],
column[1].as_bytes()
);
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[3].as_bytes()
);
}

#[test]
fn we_can_transpose_u64_column_with_offset() {
type T = u64;
let column: Vec<T> = vec![1, 2, 3];
let offset = 2;
let num_columns = 3;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset + 1);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(&transpose[0..data_size], 0_u64.as_bytes());
assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes());
assert_eq!(&transpose[2 * data_size..3 * data_size], 0_u64.as_bytes());
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[2].as_bytes()
);
assert_eq!(
&transpose[4 * data_size..5 * data_size],
column[0].as_bytes()
);
assert_eq!(&transpose[5 * data_size..6 * data_size], 0_u64.as_bytes());
}

#[test]
fn we_can_transpose_boolean_column_with_offset() {
type T = bool;
let column: Vec<T> = vec![true, false, true];
let offset = 1;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(&transpose[0..data_size], 0_u8.as_bytes());
assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes());
assert_eq!(
&transpose[2 * data_size..3 * data_size],
column[0].as_bytes()
);
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[2].as_bytes()
);
}

#[test]
fn we_can_transpose_i64_column() {
type T = i64;
let column: Vec<T> = vec![0, 1, 2, 3];
let offset = 0;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(
&transpose[0..data_size],
column[0].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[data_size..2 * data_size],
column[2].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[2 * data_size..3 * data_size],
column[1].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[3].wrapping_sub(T::MIN).as_bytes()
);
}

#[test]
fn we_can_transpose_i128_column() {
type T = i128;
let column: Vec<T> = vec![0, 1, 2, 3];
let offset = 0;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(
&transpose[0..data_size],
column[0].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[data_size..2 * data_size],
column[2].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[2 * data_size..3 * data_size],
column[1].wrapping_sub(T::MIN).as_bytes()
);
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[3].wrapping_sub(T::MIN).as_bytes()
);
}

#[test]
fn we_can_transpose_u64_array_column() {
type T = [u64; 4];
let column: Vec<T> = vec![[0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]];
let offset = 0;
let num_columns = 2;
let data_size = std::mem::size_of::<T>();

let expected_len = data_size * (column.len() + offset);

let transpose = transpose_column(&column, offset, num_columns, data_size);

assert_eq!(transpose.len(), expected_len);

assert_eq!(&transpose[0..data_size], column[0].as_bytes());
assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes());
assert_eq!(
&transpose[2 * data_size..3 * data_size],
column[1].as_bytes()
);
assert_eq!(
&transpose[3 * data_size..4 * data_size],
column[3].as_bytes()
);
}
}
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/proof_primitive/dory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ type DeferredG1 = deferred_msm::DeferredMSM<G1Affine, F>;
type DeferredG2 = deferred_msm::DeferredMSM<G2Affine, F>;

mod pairings;
mod transpose;
Loading

0 comments on commit 34aea2b

Please sign in to comment.