From 6f963f6094ab45cfbf88e4fca342b253c387daaa Mon Sep 17 00:00:00 2001 From: Andrew Reisner Date: Fri, 13 Sep 2024 14:27:37 -0600 Subject: [PATCH] Add ParBSR matmult. --- raptor/core/matrix.hpp | 4 +- raptor/ruge_stuben/par_air_solver.cpp | 2 +- raptor/util/linalg/add.cpp | 92 +++++---- raptor/util/linalg/par_matmult.cpp | 260 ++++++++++++++++---------- 4 files changed, 227 insertions(+), 131 deletions(-) diff --git a/raptor/core/matrix.hpp b/raptor/core/matrix.hpp index ecbd3838..947bb092 100644 --- a/raptor/core/matrix.hpp +++ b/raptor/core/matrix.hpp @@ -1039,7 +1039,9 @@ class BSRMatrix : public CSRMatrix void spmv_append_T(const double* x, double* b) const; void spmv_append_neg(const double* x, double* b) const; void spmv_append_neg_T(const double* x, double* b) const; - void spmv_residual(const double* x, const double* b, double* r) const; + void spmv_residual(const double* x, const double* b, double* r) const; + + void add_append(BSRMatrix * A, BSRMatrix * C, bool remove_dup = true); format_t format() { diff --git a/raptor/ruge_stuben/par_air_solver.cpp b/raptor/ruge_stuben/par_air_solver.cpp index 6e1f2fdd..d8543f86 100644 --- a/raptor/ruge_stuben/par_air_solver.cpp +++ b/raptor/ruge_stuben/par_air_solver.cpp @@ -18,7 +18,7 @@ void ParAIRSolver::extend_hier(T & A) { levels.back()->R = R; levels.emplace_back(new ParLevel()); auto & level = *levels.back(); - auto AP = A.mult(P, tap_level); + auto AP = A.mult(P); auto coarse_A = R->mult(AP); level.A = coarse_A; diff --git a/raptor/util/linalg/add.cpp b/raptor/util/linalg/add.cpp index 765c91e5..edeb5b2f 100644 --- a/raptor/util/linalg/add.cpp +++ b/raptor/util/linalg/add.cpp @@ -36,45 +36,73 @@ CSRMatrix* CSRMatrix::add(CSRMatrix* B, bool remove_dup) return C; } -void CSRMatrix::add_append(CSRMatrix* B, CSRMatrix* C, bool remove_dup) +namespace impl { + +template struct is_bsr : std::false_type {}; +template <> struct is_bsr : std::true_type {}; +template inline constexpr bool is_bsr_v = is_bsr::value; + +template || + std::is_same_v, + bool> = true> +void add_append(const T & A, const T & B, T & C, bool remove_dup) { - int start, end; + auto vals = [](auto & mat) -> auto & { + if constexpr (is_bsr_v) return mat.block_vals; + else return mat.vals; + }; + C.resize(A.n_rows, A.n_cols); + int C_nnz = A.nnz + B.nnz; + C.idx2.resize(C_nnz); + vals(C).resize(C_nnz); - C->resize(n_rows, n_cols); - int C_nnz = nnz + B->nnz; - C->idx2.resize(C_nnz); - C->vals.resize(C_nnz); + auto copy_vals = [b_size = A.b_size](auto beg, auto end, auto out) { + (void) b_size; + if constexpr (is_bsr_v) { + for (; beg != end; ++beg, ++out) { + auto val = new double[b_size]; + std::copy(val, val + b_size, *beg); + *out = val; + } + } else std::copy(beg, end, out); + }; C_nnz = 0; - C->idx1[0] = 0; - for (int i = 0; i < n_rows; i++) + C.idx1[0] = 0; + for (int i = 0; i < A.n_rows; i++) { - start = idx1[i]; - end = idx1[i+1]; - std::copy(idx2.begin() + start, - idx2.begin() + end, - C->idx2.begin() + C_nnz); - std::copy(vals.begin() + start, - vals.begin() + end, - C->vals.begin() + C_nnz); - C_nnz += (end - start); + auto add_row = [&](const T & src, T & dst) { + auto start = src.idx1[i]; + auto end = src.idx1[i+1]; + std::copy(src.idx2.begin() + start, + src.idx2.begin() + end, + dst.idx2.begin() + C_nnz); + copy_vals(vals(src).begin() + start, + vals(src).begin() + end, + vals(dst).begin() + C_nnz); + return end - start; + }; - start = B->idx1[i]; - end = B->idx1[i+1]; - std::copy(B->idx2.begin() + start, - B->idx2.begin() + end, - C->idx2.begin() + C_nnz); - std::copy(B->vals.begin() + start, - B->vals.begin() + end, - C->vals.begin() + C_nnz); - C_nnz += (end - start); - - C->idx1[i+1] = C_nnz; + C_nnz += add_row(A, C); + C_nnz += add_row(B, C); + + C.idx1[i+1] = C_nnz; } - C->nnz = C_nnz; - C->sort(); - if (remove_dup) - C->remove_duplicates(); + C.nnz = C_nnz; + C.sort(); + if (remove_dup) + C.remove_duplicates(); +} +} + +void CSRMatrix::add_append(CSRMatrix* B, CSRMatrix* C, bool remove_dup) +{ + impl::add_append(*this, *B, *C, remove_dup); +} + +void BSRMatrix::add_append(BSRMatrix * B, BSRMatrix * C, bool remove_dup) +{ + impl::add_append(*this, *B, *C, remove_dup); } CSRMatrix* CSRMatrix::subtract(CSRMatrix* B) diff --git a/raptor/util/linalg/par_matmult.cpp b/raptor/util/linalg/par_matmult.cpp index e86c2f2d..9a83f29e 100644 --- a/raptor/util/linalg/par_matmult.cpp +++ b/raptor/util/linalg/par_matmult.cpp @@ -1,6 +1,7 @@ // Copyright (c) 2015-2017, RAPtor Developer Team // License: Simplified BSD, http://opensource.org/licenses/BSD-2-Clause #include "raptor/core/par_matrix.hpp" +#include "raptor/core/matrix_traits.hpp" using namespace raptor; @@ -49,7 +50,7 @@ ParCSRMatrix* init_matrix(T* A, U* B) if (A->partition == B->partition) { - C = init_mat(A); + C = init_mat(A); } else { @@ -76,40 +77,57 @@ ParCSRMatrix* init_matrix(T* A, U* B) return C; } -ParCSRMatrix* ParCSRMatrix::mult(ParCSRMatrix* B, bool tap) +template = true> +T * matmult(T & A, T & B) { - if (tap) - { - return this->tap_mult(B); - } - - // Check that communication package has been initialized - if (comm == NULL) + // Check that communication package has been initialized + if (A.comm == NULL) { - comm = new ParComm(partition, off_proc_column_map, on_proc_column_map); + A.comm = new ParComm(A.partition, A.off_proc_column_map, A.on_proc_column_map); } // Initialize C (matrix to be returned) - ParCSRMatrix* C = init_matrix(this, B); + auto C = init_matrix(&A, &B); std::vector send_buffer; // Communicate data and multiply - comm->init_par_mat_comm(B, send_buffer); + A.comm->init_par_mat_comm(&B, send_buffer); + using seq_t = sequential_matrix_t; // Fully Local Computation - CSRMatrix* C_on_on = on_proc->mult((CSRMatrix*) B->on_proc); - CSRMatrix* C_on_off = on_proc->mult((CSRMatrix*) B->off_proc); + auto C_on_on = A.on_proc->mult(dynamic_cast(B.on_proc)); + auto C_on_off = A.on_proc->mult(dynamic_cast(B.off_proc)); - CSRMatrix* recv_mat = comm->complete_mat_comm(); + auto recv_mat = A.comm->complete_mat_comm(A.on_proc->b_rows, A.on_proc->b_cols); - mult_helper(B, C, recv_mat, C_on_on, C_on_off); + A.mult_helper(&B, C, recv_mat, C_on_on, C_on_off); delete C_on_on; delete C_on_off; delete recv_mat; // Return matrix containing product - return C; + if constexpr (is_bsr_v) + return dynamic_cast(C); + else + return C; +} + +ParCSRMatrix* ParCSRMatrix::mult(ParCSRMatrix* B, bool tap) +{ + if (tap) + { + return this->tap_mult(B); + } + + auto A_bsr = dynamic_cast(this); + auto B_bsr = dynamic_cast(B); + if (A_bsr && B_bsr) { + assert((A_bsr->on_proc->b_rows == B_bsr->on_proc->b_rows) && + (A_bsr->on_proc->b_cols == B_bsr->on_proc->b_cols)); + return matmult(*A_bsr, *B_bsr); + } else + return matmult(*this, *B); } ParCSRMatrix* ParCSRMatrix::tap_mult(ParCSRMatrix* B) @@ -118,7 +136,7 @@ ParCSRMatrix* ParCSRMatrix::tap_mult(ParCSRMatrix* B) if (tap_mat_comm == NULL) { // Always 2-step - tap_mat_comm = new TAPComm(partition, off_proc_column_map, + tap_mat_comm = new TAPComm(partition, off_proc_column_map, on_proc_column_map, false); } @@ -178,7 +196,7 @@ ParCSRMatrix* ParCSRMatrix::mult_T(ParCSCMatrix* A, bool tap) CSRMatrix* Ctmp = mult_T_partial(A); std::vector send_buffer; - A->comm->init_mat_comm_T(send_buffer, Ctmp->idx1, Ctmp->idx2, + A->comm->init_mat_comm_T(send_buffer, Ctmp->idx1, Ctmp->idx2, Ctmp->vals); CSRMatrix* C_on_on = on_proc->mult_T((CSCMatrix*) A->on_proc); @@ -202,7 +220,7 @@ ParCSRMatrix* ParCSRMatrix::tap_mult_T(ParCSCMatrix* A) { if (A->tap_mat_comm == NULL) { - A->tap_mat_comm = new TAPComm(A->partition, A->off_proc_column_map, + A->tap_mat_comm = new TAPComm(A->partition, A->off_proc_column_map, A->on_proc_column_map, false); } @@ -212,7 +230,7 @@ ParCSRMatrix* ParCSRMatrix::tap_mult_T(ParCSCMatrix* A) CSRMatrix* Ctmp = mult_T_partial(A); std::vector send_buffer; - A->tap_mat_comm->init_mat_comm_T(send_buffer, Ctmp->idx1, Ctmp->idx2, + A->tap_mat_comm->init_mat_comm_T(send_buffer, Ctmp->idx1, Ctmp->idx2, Ctmp->vals); CSRMatrix* C_on_on = on_proc->mult_T((CSCMatrix*) A->on_proc); @@ -236,130 +254,179 @@ ParMatrix* ParMatrix::mult(ParCSRMatrix* B, bool tap) { int rank; RAPtor_MPI_Comm_rank(RAPtor_MPI_COMM_WORLD, &rank); - if (rank == 0) + if (rank == 0) printf("Multiplication is not implemented for these ParMatrix types.\n"); return NULL; } -void ParCSRMatrix::mult_helper(ParCSRMatrix* B, ParCSRMatrix* C, - CSRMatrix* recv_mat, CSRMatrix* C_on_on, CSRMatrix* C_on_off) -{ +template = true> +void matmult_helper(T & A, T & B, T & C, + sequential_matrix_t & recv_mat, + sequential_matrix_t & C_on_on, + sequential_matrix_t & C_on_off) { // Set dimensions of C - C->global_num_rows = global_num_rows; - C->global_num_cols = B->global_num_cols; - C->local_num_rows = local_num_rows; + C.global_num_rows = A.global_num_rows; + C.global_num_cols = B.global_num_cols; + C.local_num_rows = A.local_num_rows; + + C.on_proc_column_map = B.get_on_proc_column_map(); + C.local_row_map = A.get_local_row_map(); + C.on_proc_num_cols = C.on_proc_column_map.size(); - C->on_proc_column_map = B->get_on_proc_column_map(); - C->local_row_map = get_local_row_map(); - C->on_proc_num_cols = C->on_proc_column_map.size(); - // Initialize nnz as 0 (will increment this as nonzeros are added) - C->local_nnz = 0; + C.local_nnz = 0; // Declare Variables int row_start, row_end; int global_col; - - // Split recv_mat into on and off proc portions - CSRMatrix* recv_on = new CSRMatrix(recv_mat->n_rows, -1); - CSRMatrix* recv_off = new CSRMatrix(recv_mat->n_rows, -1); - auto part_to_col = B->map_partition_to_local(); - recv_on->idx1[0] = 0; - recv_off->idx1[0] = 0; - for (int i = 0; i < recv_mat->n_rows; i++) + // Split recv_mat into on and off proc portions + auto create_mat = [&]() { + if constexpr (is_bsr_v) + return BSRMatrix(recv_mat.n_rows, -1, + recv_mat.b_rows, recv_mat.b_cols); + else return CSRMatrix(recv_mat.n_rows, -1); + }; + auto recv_on = create_mat(); + auto recv_off = create_mat(); + + auto append_val = [](const sequential_matrix_t & src, + sequential_matrix_t & dst, + int offset) { + if constexpr (is_bsr_v) { + double * new_val = new double[src.b_size]; + std::copy(src.block_vals[offset], src.block_vals[offset] + src.b_size, new_val); + dst.block_vals.emplace_back(new_val); + } else { + dst.vals.emplace_back(src.vals[offset]); + } + }; + auto part_to_col = B.map_partition_to_local(); + recv_on.idx1[0] = 0; + recv_off.idx1[0] = 0; + for (int i = 0; i < recv_mat.n_rows; i++) { - row_start = recv_mat->idx1[i]; - row_end = recv_mat->idx1[i+1]; + row_start = recv_mat.idx1[i]; + row_end = recv_mat.idx1[i+1]; for (int j = row_start; j < row_end; j++) { - global_col = recv_mat->idx2[j]; - if (global_col < B->partition->first_local_col || - global_col > B->partition->last_local_col) + global_col = recv_mat.idx2[j]; + if (global_col < B.partition->first_local_col || + global_col > B.partition->last_local_col) { - recv_off->idx2.emplace_back(global_col); - recv_off->vals.emplace_back(recv_mat->vals[j]); + recv_off.idx2.emplace_back(global_col); + append_val(recv_mat, recv_off, j); } else { - recv_on->idx2.emplace_back(part_to_col[global_col - - B->partition->first_local_col]); - recv_on->vals.emplace_back(recv_mat->vals[j]); + recv_on.idx2.emplace_back(part_to_col[global_col - + B.partition->first_local_col]); + append_val(recv_mat, recv_on, j); } } - recv_on->idx1[i+1] = recv_on->idx2.size(); - recv_off->idx1[i+1] = recv_off->idx2.size(); + recv_on.idx1[i+1] = recv_on.idx2.size(); + recv_off.idx1[i+1] = recv_off.idx2.size(); } - recv_on->nnz = recv_on->idx2.size(); - recv_off->nnz = recv_off->idx2.size(); + recv_on.nnz = recv_on.idx2.size(); + recv_off.nnz = recv_off.idx2.size(); // Calculate global_to_C and B_to_C column maps std::map global_to_C; - std::vector B_to_C(B->off_proc_num_cols); + std::vector B_to_C(B.off_proc_num_cols); - std::copy(recv_off->idx2.begin(), recv_off->idx2.end(), - std::back_inserter(C->off_proc_column_map)); - for (std::vector::iterator it = B->off_proc_column_map.begin(); - it != B->off_proc_column_map.end(); ++it) + std::copy(recv_off.idx2.begin(), recv_off.idx2.end(), + std::back_inserter(C.off_proc_column_map)); + for (std::vector::iterator it = B.off_proc_column_map.begin(); + it != B.off_proc_column_map.end(); ++it) { - C->off_proc_column_map.emplace_back(*it); + C.off_proc_column_map.emplace_back(*it); } - std::sort(C->off_proc_column_map.begin(), C->off_proc_column_map.end()); + std::sort(C.off_proc_column_map.begin(), C.off_proc_column_map.end()); int prev_col = -1; - C->off_proc_num_cols = 0; - for (std::vector::iterator it = C->off_proc_column_map.begin(); - it != C->off_proc_column_map.end(); ++it) + C.off_proc_num_cols = 0; + for (std::vector::iterator it = C.off_proc_column_map.begin(); + it != C.off_proc_column_map.end(); ++it) { if (*it != prev_col) { - global_to_C[*it] = C->off_proc_num_cols; - C->off_proc_column_map[C->off_proc_num_cols++] = *it; + global_to_C[*it] = C.off_proc_num_cols; + C.off_proc_column_map[C.off_proc_num_cols++] = *it; prev_col = *it; } } - C->off_proc_column_map.resize(C->off_proc_num_cols); + C.off_proc_column_map.resize(C.off_proc_num_cols); - for (int i = 0; i < B->off_proc_num_cols; i++) + for (int i = 0; i < B.off_proc_num_cols; i++) { - global_col = B->off_proc_column_map[i]; + global_col = B.off_proc_column_map[i]; B_to_C[i] = global_to_C[global_col]; } - for (std::vector::iterator it = recv_off->idx2.begin(); - it != recv_off->idx2.end(); ++it) + for (std::vector::iterator it = recv_off.idx2.begin(); + it != recv_off.idx2.end(); ++it) { *it = global_to_C[*it]; } - for (std::vector::iterator it = C_on_off->idx2.begin(); - it != C_on_off->idx2.end(); ++it) + for (std::vector::iterator it = C_on_off.idx2.begin(); + it != C_on_off.idx2.end(); ++it) { *it = B_to_C[*it]; } - C->off_proc_num_cols = C->off_proc_column_map.size(); - recv_on->n_cols = B->on_proc->n_cols; - recv_off->n_cols = C->off_proc_num_cols; - C_on_off->n_cols = C->off_proc_num_cols; - - // Multiply A->off_proc * B->recv_on -> C_off_on - CSRMatrix* C_off_on = off_proc->mult(recv_on); - delete recv_on; - - // Multiply A->off_proc * B->recv_off -> C_off_off - CSRMatrix* C_off_off = off_proc->mult(recv_off); - delete recv_off; + C.off_proc_num_cols = C.off_proc_column_map.size(); + recv_on.n_cols = B.on_proc->n_cols; + recv_off.n_cols = C.off_proc_num_cols; + C_on_off.n_cols = C.off_proc_num_cols; + + auto cast_ptr = [](Matrix * ptr) { + auto ret = dynamic_cast*>(ptr); + assert(ret); + return ret; + }; + // Multiply A.off_proc * B.recv_on -> C_off_on + auto C_off_on = cast_ptr(A.off_proc->mult(&recv_on)); + + // Multiply A.off_proc * B.recv_off -> C_off_off + auto C_off_off = cast_ptr(A.off_proc->mult(&recv_off)); // Create C->on_proc by adding C_on_on + C_off_on - C_on_on->add_append(C_off_on, (CSRMatrix*) C->on_proc); + C_on_on.add_append(C_off_on, cast_ptr(C.on_proc)); delete C_off_on; - // Create C->off_proc by adding C_off_on + C_off_off - C_on_off->add_append(C_off_off, (CSRMatrix*) C->off_proc); + // Create C.off_proc by adding C_off_on + C_off_off + C_on_off.add_append(C_off_off, cast_ptr(C.off_proc)); delete C_off_off; - C->local_nnz = C->on_proc->nnz + C->off_proc->nnz; + C.local_nnz = C.on_proc->nnz + C.off_proc->nnz; } + +void ParCSRMatrix::mult_helper(ParCSRMatrix* B, ParCSRMatrix* C, + CSRMatrix* recv_mat, CSRMatrix* C_on_on, CSRMatrix* C_on_off) +{ + auto all_par_bsr = [](auto ... p) { + return (dynamic_cast(p) && ...); + }; + auto all_bsr = [](auto ... p) { + return (dynamic_cast(p) && ...); + }; + + if (all_par_bsr(this, B, C) && all_bsr(recv_mat, C_on_on, C_on_off)) { + [](auto & ... m) { + auto cast = [](auto & v) -> auto & { + if constexpr (std::is_same_v) + return dynamic_cast(v); + else if constexpr (std::is_same_v) + return dynamic_cast(v); + }; + matmult_helper(cast(m)...); + }(*this, *B, *C, *recv_mat, *C_on_on, *C_on_off); + } else { + matmult_helper(*this, *B, *C, *recv_mat, *C_on_on, *C_on_off); + } +} + + CSRMatrix* ParCSRMatrix::mult_T_partial(CSCMatrix* A_off) { CSRMatrix* C_off_on = on_proc->mult_T(A_off, on_proc_column_map.data()); @@ -376,12 +443,12 @@ CSRMatrix* ParCSRMatrix::mult_T_partial(CSCMatrix* A_off) CSRMatrix* ParCSRMatrix::mult_T_partial(ParCSCMatrix* A) { // Declare Variables - return mult_T_partial((CSCMatrix*) A->off_proc); + return mult_T_partial((CSCMatrix*) A->off_proc); } void ParCSRMatrix::mult_T_combine(ParCSCMatrix* P, ParCSRMatrix* C, CSRMatrix* recv_mat, CSRMatrix* C_on_on, CSRMatrix* C_off_on) -{ +{ int start, end, ctr; int col, col_C; @@ -442,7 +509,7 @@ void ParCSRMatrix::mult_T_combine(ParCSCMatrix* P, ParCSRMatrix* C, CSRMatrix* r *it = part_to_col[(*it - partition->first_local_col)]; } - // Multiply on_proc + // Multiply on_proc recv_on->n_cols = C->on_proc_num_cols; C_on_on->add_append(recv_on, (CSRMatrix*) C->on_proc); @@ -459,12 +526,12 @@ void ParCSRMatrix::mult_T_combine(ParCSCMatrix* P, ParCSRMatrix* C, CSRMatrix* r // Create set of global columns in B_off_proc and recv_mat std::set C_col_set; - for (std::vector::iterator it = recv_off->idx2.begin(); + for (std::vector::iterator it = recv_off->idx2.begin(); it != recv_off->idx2.end(); ++it) { C_col_set.insert(*it); } - for (std::vector::iterator it = off_proc_column_map.begin(); + for (std::vector::iterator it = off_proc_column_map.begin(); it != off_proc_column_map.end(); ++it) { C_col_set.insert(*it); @@ -475,7 +542,7 @@ void ParCSRMatrix::mult_T_combine(ParCSCMatrix* P, ParCSRMatrix* C, CSRMatrix* r { C->off_proc_column_map.reserve(C->off_proc_num_cols); } - for (std::set::iterator it = C_col_set.begin(); + for (std::set::iterator it = C_col_set.begin(); it != C_col_set.end(); ++it) { global_to_C[*it] = C->off_proc_column_map.size(); @@ -558,4 +625,3 @@ void ParCSRMatrix::mult_T_combine(ParCSCMatrix* P, ParCSRMatrix* C, CSRMatrix* r delete recv_on; delete recv_off; } -