From 3268c0bfa83b77ee1bee042c23e81ab6d7456241 Mon Sep 17 00:00:00 2001 From: Amanda Bienz Date: Wed, 22 May 2024 21:38:47 -0600 Subject: [PATCH] Added test for CSR to BSR conversion --- raptor/core/matrix.hpp | 6 +- raptor/core/tests/CMakeLists.txt | 4 +- raptor/core/tests/test_bsr_matrix.cpp | 81 +++++++++++++++++---------- raptor/util/tests/CMakeLists.txt | 6 +- 4 files changed, 61 insertions(+), 36 deletions(-) diff --git a/raptor/core/matrix.hpp b/raptor/core/matrix.hpp index 5d866b07..40df47d7 100644 --- a/raptor/core/matrix.hpp +++ b/raptor/core/matrix.hpp @@ -1007,13 +1007,13 @@ class BSRMatrix : public CSRMatrix { for (int block = 0; block < b_rows; block++) { - int csr_row = bsr_row+block; + int csr_row = bsr_row*b_rows+block; for (int j = A->idx1[csr_row]; j < A->idx1[csr_row+1]; j++) { int csr_col = A->idx2[j]; int bsr_col = csr_col / b_rows; - if (idx[bsr_col] != -1) + if (idx[bsr_col] == -1) { idx[bsr_col] = idx2.size(); idx2.push_back(bsr_col); @@ -1023,7 +1023,7 @@ class BSRMatrix : public CSRMatrix int idx_col = csr_col % b_cols; block_vals[idx[bsr_col]][idx_row*b_rows + idx_col] = A->vals[j]; } - } + } idx1[bsr_row+1] = idx2.size(); // Reset IDX array for next BSR row diff --git a/raptor/core/tests/CMakeLists.txt b/raptor/core/tests/CMakeLists.txt index 74026690..0d879f81 100644 --- a/raptor/core/tests/CMakeLists.txt +++ b/raptor/core/tests/CMakeLists.txt @@ -51,4 +51,6 @@ add_executable(test_transpose test_transpose.cpp) target_link_libraries(test_transpose raptor ${MPI_LIBRARIES} googletest pthread ) add_test(TransposeTest ./test_transpose) - +add_executable(test_bsr_matrix test_bsr_matrix.cpp) +target_link_libraries(test_bsr_matrix raptor ${MPI_LIBRARIES} googletest pthread ) +add_test(BSRMatrixTest ./test_bsr_matrix) diff --git a/raptor/core/tests/test_bsr_matrix.cpp b/raptor/core/tests/test_bsr_matrix.cpp index fd22d3d0..850a66da 100644 --- a/raptor/core/tests/test_bsr_matrix.cpp +++ b/raptor/core/tests/test_bsr_matrix.cpp @@ -15,42 +15,65 @@ int main(int argc, char** argv) TEST(BSRMatrixTest, TestsInCore) { -/* - std::vector> indx = {{0,0}, {0,1}, {1,1}, {2,1}, {2,2}}; + // Matrix [0, 1], [1, 0] + // [2, 0], [0, 2] + // [3, 0], [0, 0] + // [0, 4], [0, 0] + int n_csr = 4; + int nnz_csr = 6; + std::vector rowptr_csr = {0, 2, 4, 5, 6}; + std::vector col_idx_csr = {1, 2, 0, 3, 0, 1}; + std::vector data_csr = {1, 1, 2, 2, 3, 4}; + CSRMatrix* A_csr = new CSRMatrix(n_csr, n_csr, rowptr_csr, col_idx_csr, data_csr); + + int n = 2; // 2 blocks by 2 blocks + int br = 2; // blocks are each 2x2 + int bs = 4; + int nnz = 3; // 3 blocks + std::vector rowptr = {0, 2, 3}; + std::vector col_idx = {0, 1, 0}; + std::vector data = {0, 1, 2, 0, 1, 0, 0, 2, 3, 0, 0, 4}; - int rows_in_block = 2; - int cols_in_block = 2; - int n = 6; - - // Create BSR Matrices (6x6) - const BSRMatrix A_BSR1(n, n, rows_in_block, cols_in_block, row_ptr, cols, vals); - BSRMatrix A_BSR2(n, n, rows_in_block, cols_in_block); - - // Add blocks - for(int i=0; iidx1[0] = 0; + for (int i = 0; i < n; i++) + { + A->idx1[i+1] = rowptr[i+1]; + for (int j = A->idx1[i]; j < A->idx1[i+1]; j++) + { + A->idx2.push_back(col_idx[j]); + double* vals = new double[bs]; + for (int k = 0; k < bs; k++) + vals[k] = data[j*bs + k]; + A->block_vals.push_back(vals); + } } - // Check dimensions of A_BSR2 - ASSERT_EQ(A_BSR2.nnz, A_BSR1.nnz); - //ASSERT_EQ(A_BSR2.n_blocks, A_BSR2.n_blocks); + // Call method that converts CSR to BSR + BSRMatrix* A_conv = new BSRMatrix(A_csr, br, br); - // Check row_ptr - for (int i=0; in_rows, A->n_rows); + ASSERT_EQ(A_conv->n_cols, A->n_cols); + ASSERT_EQ(A_conv->b_rows, A->b_rows); + ASSERT_EQ(A_conv->b_cols, A->b_cols); + ASSERT_EQ(A_conv->b_size, A->b_size); - // Check column indices - for (int i=0; in_rows; i++) { - ASSERT_EQ(A_BSR2.idx2[i], A_BSR1.idx2[i]); + ASSERT_EQ(A_conv->idx1[i+1], A->idx1[i+1]); + for (int j = A->idx1[i]; j < A->idx1[i+1]; j++) + { + ASSERT_EQ(A_conv->idx2[j], A->idx2[j]); + for (int k = 0; k < A->b_size; k++) + ASSERT_EQ(A_conv->block_vals[j][k], A->block_vals[j][k]); + } } + + delete A_csr; + delete A; + delete A_conv; - // Check data - //for (int i=0; i