diff --git a/raptor/ruge_stuben/par_interpolation.cpp b/raptor/ruge_stuben/par_interpolation.cpp index 33ace423..a37fa80e 100644 --- a/raptor/ruge_stuben/par_interpolation.cpp +++ b/raptor/ruge_stuben/par_interpolation.cpp @@ -1978,6 +1978,29 @@ ParBSRMatrix * one_point_interpolation(const ParBSRMatrix & A, return ret; } +template +using is_bsr_or_csr = std::enable_if_t || + std::is_same_v, bool>; + +template struct is_bsr : std::false_type {}; +template <> struct is_bsr : std::true_type {}; +template inline constexpr bool is_bsr_v = is_bsr::value; + +BSRMatrix & bsr_cast(Matrix &mat) { return dynamic_cast(mat); } +const BSRMatrix & bsr_cast(const Matrix & mat) { return dynamic_cast(mat); } + +template struct matrix_value; +template <> struct matrix_value { using type = double; }; +template <> struct matrix_value { using type = double*; }; +template +using matrix_value_t = typename matrix_value::type; + +template struct sequential_matrix; +template<> struct sequential_matrix { using type = BSRMatrix; }; +template<> struct sequential_matrix { using type = CSRMatrix; }; +template +using sequential_matrix_t = typename sequential_matrix::type; + namespace lair { namespace { @@ -1986,9 +2009,9 @@ namespace { Helper type providing access to received rows based on whether they are on_proc or off_proc */ -template +template struct comm_rows { - comm_rows(const ParCSRMatrix & A, + comm_rows(const T & A, CSRMatrix * mat); ~comm_rows() { if (rmat) delete rmat; @@ -2008,35 +2031,34 @@ struct comm_rows { std::vector idx; } diag, offd; - template struct row_view { - using value_type = std::conditional_t; + using value_type = matrix_value_t; auto & idx2() { return mat->idx2[off]; } value_type val() { - if constexpr (is_bsr) + if constexpr (is_bsr_v) return mat->block_vals[off]; else return mat->vals[off]; }; int off; - T * mat; + sequential_matrix_t * mat; }; private: comm_rows(CSRMatrix * mat); - using mat_t = std::conditional_t; template void iter(int row, F && f, const ptrs & pts) const { for (int i = pts.ptr[row]; i < pts.ptr[row + 1]; ++i) { - std::forward(f)(row_view{i, rmat}); + std::forward(f)(row_view{i, rmat}); } } - mat_t *rmat; + sequential_matrix_t *rmat; }; +template comm_rows(const T &, CSRMatrix *) -> comm_rows; struct offd_map_rowptr { std::vector diag_rowptr; @@ -2050,12 +2072,12 @@ struct offd_map_rowptr { This computes the rowptr for R and discovers set of off proc columns. Forward and backward column maps for off proc columns are also computed. */ -template +template = true> offd_map_rowptr map_offd_fill_rowptr(const ParCSRMatrix & S, const splitting_t & splitting, const std::vector & cpts, fpoint_distance distance, - const comm_rows & recv_rows) { + const comm_rows & recv_rows) { offd_map_rowptr ret; struct nnz_t { @@ -2130,15 +2152,11 @@ offd_map_rowptr map_offd_fill_rowptr(const ParCSRMatrix & S, return ret; } - -ParCSRMatrix * create_R(const ParCSRMatrix & A, - const ParCSRMatrix & S, - const splitting_t & splitting, - offd_map_rowptr && rowptr_and_colmap) { - - - bool isbsr = dynamic_cast(&A); - +template = true> +T * create_R(const T & A, + const ParCSRMatrix & S, + const splitting_t & splitting, + offd_map_rowptr && rowptr_and_colmap) { int local_rows = std::count(splitting.on_proc.cbegin(), splitting.on_proc.cend(), Selected); @@ -2154,20 +2172,19 @@ ParCSRMatrix * create_R(const ParCSRMatrix & A, mat.off_proc->idx1 = std::move(rac.offd_rowptr); mat.off_proc_column_map = std::move(rac.offd_colmap); mat.on_proc_column_map = A.on_proc_column_map; - [isbsr](auto & ... mats) { + [](auto & ... mats) { (mats.idx2.resize(mats.idx1.back()), ...); - if (isbsr) { - (dynamic_cast(mats).block_vals.resize( + if constexpr (is_bsr_v) { + (bsr_cast(mats).block_vals.resize( mats.idx1.back()), ...); } else (mats.vals.resize(mats.idx1.back()), ...); }(*mat.on_proc, *mat.off_proc); }; - if (isbsr) { - auto & bsr = dynamic_cast(A); + if constexpr (is_bsr_v) { auto * R = new ParBSRMatrix(S.partition, global_rows, S.global_num_cols, local_rows, S.on_proc_num_cols, off_proc_num_cols, - bsr.on_proc->b_rows, bsr.on_proc->b_cols); + A.on_proc->b_rows, A.on_proc->b_cols); move_data(std::move(rowptr_and_colmap), *R); return R; } else { @@ -2178,17 +2195,6 @@ ParCSRMatrix * create_R(const ParCSRMatrix & A, } } - -ParBSRMatrix * create_R(const ParBSRMatrix & A, - const ParCSRMatrix & S, - const splitting_t & splitting, - offd_map_rowptr && rowptr_and_colmap) { - return dynamic_cast( - create_R(dynamic_cast(A), - S, splitting, std::move(rowptr_and_colmap))); -} - - ParComm create_neighborhood_comm(const ParCSRMatrix & R, const ParCSRMatrix & A) { constexpr int tag = 9345; return ParComm(A.partition, R.off_proc_column_map, @@ -2209,10 +2215,10 @@ auto get_cpoints(const std::vector & split) { #include "pyamg_utils.hpp" -template +template = true> auto fill_colind(std::size_t row, const ParCSRMatrix & S, - const comm_rows & recv_neighbors, + const comm_rows & recv_neighbors, const splitting_t & splitting, const std::vector & cpts, fpoint_distance distance, @@ -2285,13 +2291,13 @@ auto fill_colind(std::size_t row, return std::make_pair(ind, ind_off); } -template +template struct row_searcher { - using matref = std::conditional_t; - using value_type = std::conditional_t; + using matref = const sequential_matrix_t &; + using value_type = matrix_value_t; - row_searcher(const ParCSRMatrix & A) : + row_searcher(const T & A) : diag(dynamic_cast(*A.on_proc)), offd(dynamic_cast(*A.off_proc)), diag_colmap(A.on_proc_column_map), @@ -2303,7 +2309,7 @@ struct row_searcher auto search = [&](matref mat, const std::vector & colmap) { for (int off = mat.idx1[local_row]; off < mat.idx1[local_row + 1]; ++off) { if (colmap[mat.idx2[off]] == global_col) { - if constexpr (is_bsr) + if constexpr (is_bsr_v) ret.emplace(mat.block_vals[off]); else ret.emplace(mat.vals[off]); @@ -2320,6 +2326,8 @@ struct row_searcher matref diag, offd; const std::vector & diag_colmap, offd_colmap; }; +template +row_searcher(const T&)->row_searcher; struct neighborhood_loop { template @@ -2335,7 +2343,7 @@ struct neighborhood_loop { const ParCSRMatrix & R; }; -template +template struct neighborhood_scan { template @@ -2347,7 +2355,7 @@ struct neighborhood_scan { }; auto offd_finder = [&](int i) { return [&, i, this](int col) { - std::optional::value_type> ret; + std::optional> ret; auto search = [&](auto rview) { if (rview.idx2() == col) @@ -2374,27 +2382,26 @@ struct neighborhood_scan { } const neighborhood_loop & loop; - const comm_rows & recv_rows; - const row_searcher & row_search; + const comm_rows & recv_rows; + const row_searcher & row_search; const std::map & offd_g2l; }; -template -neighborhood_scan(const neighborhood_loop&, const comm_rows &, - const row_searcher&, const std::map&)->neighborhood_scan; +template +neighborhood_scan(const neighborhood_loop&, const comm_rows &, + const row_searcher&, const std::map&)->neighborhood_scan; -template +template void fill_data(std::size_t row, int cpoint, int ind, int ind_off, - const ParCSRMatrix &A, const comm_rows &recv_rows, + const T &A, const comm_rows &recv_rows, ParCSRMatrix &R, const std::map &offd_g2l); template <> -void fill_data(std::size_t row, int cpoint, int ind, int ind_off, - const ParCSRMatrix &Acsr, - const comm_rows &recv_rows, - ParCSRMatrix &R, - const std::map &offd_g2l) +void fill_data(std::size_t row, int cpoint, int ind, int ind_off, + const ParBSRMatrix &A, + const comm_rows &recv_rows, + ParCSRMatrix &R, + const std::map &offd_g2l) { - const auto & A = dynamic_cast(Acsr); // assuming b_rows == b_cols auto blocksize = A.on_proc->b_rows; // Build local linear system as the submatrix A^T restricted to the neighborhood, @@ -2406,7 +2413,7 @@ void fill_data(std::size_t row, int cpoint, int ind, int ind_off, auto num_dofs = size_n * blocksize; std::vector A0(num_dofs*num_dofs, 0.0); - row_searcher row_search(A); + row_searcher row_search(A); neighborhood_loop neig_loop{row, ind, ind_off, R}; neighborhood_scan iter_neig{neig_loop, recv_rows, row_search, offd_g2l}; @@ -2510,15 +2517,15 @@ void fill_data(std::size_t row, int cpoint, int ind, int ind_off, block_ind = 0; neig_loop( [&](int off) { - dynamic_cast(*R.on_proc).block_vals[off] = get_vals(block_ind++); + bsr_cast(*R.on_proc).block_vals[off] = get_vals(block_ind++); }, [&](int off) { - dynamic_cast(*R.off_proc).block_vals[off] = get_vals(block_ind++); + bsr_cast(*R.off_proc).block_vals[off] = get_vals(block_ind++); }); // Add identity for C-point in this block row (assume data[] initialized to 0) R.on_proc->idx2[ind] = cpoint; - dynamic_cast(*R.on_proc).block_vals[ind] = + bsr_cast(*R.on_proc).block_vals[ind] = [blocksize](){ double * ident_vals = new double[blocksize*blocksize](); for (int this_row = 0; this_row < blocksize; ++this_row) { @@ -2530,11 +2537,11 @@ void fill_data(std::size_t row, int cpoint, int ind, int ind_off, template <> -void fill_data(std::size_t row, int cpoint, int ind, int ind_off, - const ParCSRMatrix &A, - const comm_rows &recv_rows, - ParCSRMatrix &R, - const std::map &offd_g2l) +void fill_data(std::size_t row, int cpoint, int ind, int ind_off, + const ParCSRMatrix &A, + const comm_rows &recv_rows, + ParCSRMatrix &R, + const std::map &offd_g2l) { // Build local linear system as the submatrix A restricted to the neighborhood, // Nf, of strongly connected F-points to the current C-point, that is A0 = @@ -2544,7 +2551,7 @@ void fill_data(std::size_t row, int cpoint, int ind, int ind_off, std::vector A0; A0.reserve(size_n * size_n); - row_searcher row_search(A); + row_searcher row_search(A); neighborhood_loop neig_loop{row, ind, ind_off, R}; neighborhood_scan iter_neig{neig_loop, recv_rows, row_search, offd_g2l}; @@ -2594,11 +2601,11 @@ void fill_data(std::size_t row, int cpoint, int ind, int ind_off, } -template -void fill_colind_and_data(const ParCSRMatrix & A, - const ParCSRMatrix &S, - const comm_rows & recv_neighbors, - const comm_rows &recv_rows, +template = true> +void fill_colind_and_data(const T & A, + const ParCSRMatrix & S, + const comm_rows & recv_neighbors, + const comm_rows &recv_rows, const splitting_t &splitting, const std::vector &cpts, fpoint_distance distance, @@ -2616,7 +2623,7 @@ void fill_colind_and_data(const ParCSRMatrix & A, auto cpoint = cpts[row]; auto [ind, ind_off] = fill_colind(row, S, recv_neighbors, splitting, cpts, distance, R); - fill_data(row, cpoint, ind, ind_off, A, recv_rows, R, offd_g2l); + fill_data(row, cpoint, ind, ind_off, A, recv_rows, R, offd_g2l); } // offd colinds are currently global, convert to local @@ -2655,16 +2662,13 @@ struct mat_value For each local row send strongly connected f-point neighbor column indices if said row is an f-point. */ -template -CSRMatrix * communicate_neighborhood(const ParCSRMatrix & A, const ParCSRMatrix & S, +template = true> +CSRMatrix * communicate_neighborhood(const T & A, const ParCSRMatrix & S, const splitting_t & split, C && comm) { - using bsr_t = std::vector; - using csr_t = std::vector; - using val_t = std::conditional_t; - std::vector rowptr(A.local_num_rows + 1); - std::vector colind; - val_t values; - + using val_t = std::vector>; + std::vector rowptr(A.local_num_rows + 1); + std::vector colind; + val_t values; if (A.local_nnz) { [&](auto & ... v) { (v.reserve(A.local_nnz),...); }(colind, values); @@ -2680,7 +2684,7 @@ CSRMatrix * communicate_neighborhood(const ParCSRMatrix & A, const ParCSRMatrix const std::vector & colmap, const Matrix & s, const std::vector & splitting) { - detail::mat_value mat_value(a); + detail::mat_value> mat_value(a); auto [beg, end] = get_bounds(a); auto [ctr_s, end_s] = get_bounds(s); @@ -2715,17 +2719,9 @@ CSRMatrix * communicate_neighborhood(const ParCSRMatrix & A, const ParCSRMatrix A.on_proc->b_rows, A.on_proc->b_cols); } - -template <> -comm_rows::comm_rows(CSRMatrix *mat) - : rmat(dynamic_cast(mat)) {} -template <> -comm_rows::comm_rows(CSRMatrix *mat) - : rmat(mat) {} - -template -comm_rows::comm_rows(const ParCSRMatrix &A, CSRMatrix *mat) - : comm_rows(mat) +template +comm_rows::comm_rows(const T & A, CSRMatrix * mat) + : rmat(dynamic_cast*>(mat)) { if (!rmat) return; @@ -2763,15 +2759,11 @@ comm_rows::comm_rows(const ParCSRMatrix &A, CSRMatrix *mat) [](auto & ... v) { ((v.shrink_to_fit()), ...); }(diag.idx, offd.idx); } -template -using par_mat = std::conditional_t; - -template -par_mat * compute_R(par_mat & A, - ParCSRMatrix & S, - const splitting_t & splitting, - fpoint_distance distance) { - using msg_t = comm_rows; +template = true> +T * compute_R(T & A, + ParCSRMatrix & S, + const splitting_t & splitting, + fpoint_distance distance) { auto pre_init = [](auto & mat) { mat.sort(); mat.on_proc->move_diag(); @@ -2779,19 +2771,19 @@ par_mat * compute_R(par_mat & A, pre_init(A); pre_init(S); - msg_t recv_neighbors(A, - (distance == fpoint_distance::two) ? - communicate_neighborhood(A, S, splitting, *A.comm) : nullptr); + comm_rows recv_neighbors(A, + (distance == fpoint_distance::two) ? + communicate_neighborhood(A, S, splitting, *A.comm) : nullptr); auto cpts = get_cpoints(splitting.on_proc); auto R = create_R(A, S, splitting, map_offd_fill_rowptr( S, splitting, cpts, distance, recv_neighbors)); - msg_t recv_rows(A, communicate_neighborhood(A, S, splitting, - create_neighborhood_comm(*R, A))); + comm_rows recv_rows(A, communicate_neighborhood(A, S, splitting, + create_neighborhood_comm(*R, A))); - fill_colind_and_data(A, S, recv_neighbors, recv_rows, splitting, cpts, distance, *R); + fill_colind_and_data(A, S, recv_neighbors, recv_rows, splitting, cpts, distance, *R); constexpr int tag = 9244; R->comm = new ParComm(R->partition, R->off_proc_column_map, R->on_proc_column_map, @@ -2807,7 +2799,7 @@ ParCSRMatrix * local_air(ParCSRMatrix & A, const splitting_t & splitting, fpoint_distance distance) { - return lair::compute_R(A, S, splitting, distance); + return lair::compute_R(A, S, splitting, distance); } @@ -2815,7 +2807,7 @@ ParBSRMatrix * local_air(ParBSRMatrix & A, ParCSRMatrix & S, const splitting_t & splitting, fpoint_distance distance) { - return lair::compute_R(A, S, splitting, distance); + return lair::compute_R(A, S, splitting, distance); } } // namespace raptor