Skip to content

Commit

Permalink
Add ParBSR matmult.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewreisner committed Sep 13, 2024
1 parent 579fc65 commit 6f963f6
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 131 deletions.
4 changes: 3 additions & 1 deletion raptor/core/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,9 @@ class BSRMatrix : public CSRMatrix
void spmv_append_T(const double* x, double* b) const;
void spmv_append_neg(const double* x, double* b) const;
void spmv_append_neg_T(const double* x, double* b) const;
void spmv_residual(const double* x, const double* b, double* r) const;
void spmv_residual(const double* x, const double* b, double* r) const;

void add_append(BSRMatrix * A, BSRMatrix * C, bool remove_dup = true);

format_t format()
{
Expand Down
2 changes: 1 addition & 1 deletion raptor/ruge_stuben/par_air_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void ParAIRSolver::extend_hier(T & A) {
levels.back()->R = R;
levels.emplace_back(new ParLevel());
auto & level = *levels.back();
auto AP = A.mult(P, tap_level);
auto AP = A.mult(P);
auto coarse_A = R->mult(AP);

level.A = coarse_A;
Expand Down
92 changes: 60 additions & 32 deletions raptor/util/linalg/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,73 @@ CSRMatrix* CSRMatrix::add(CSRMatrix* B, bool remove_dup)
return C;
}

void CSRMatrix::add_append(CSRMatrix* B, CSRMatrix* C, bool remove_dup)
namespace impl {

// Compile-time predicate: true only when T is exactly BSRMatrix (no
// cv-qualifiers, no derived types); false for every other type.
template <class T>
struct is_bsr : std::bool_constant<std::is_same_v<T, BSRMatrix>> {};

// Convenience variable template mirroring the standard *_v helpers.
template <class T>
inline constexpr bool is_bsr_v = is_bsr<T>::value;

template <class T, std::enable_if_t<std::is_same_v<T, CSRMatrix> ||
std::is_same_v<T, BSRMatrix>,
bool> = true>
void add_append(const T & A, const T & B, T & C, bool remove_dup)
{
int start, end;
// Selects the value storage of `mat` at compile time: `block_vals` when T is
// BSRMatrix, `vals` otherwise. The container is returned by reference so the
// caller can resize it or take iterators into it.
auto vals = [](auto & mat) -> auto & {
if constexpr (is_bsr_v<T>) return mat.block_vals;
else return mat.vals;
};
C.resize(A.n_rows, A.n_cols);
int C_nnz = A.nnz + B.nnz;
C.idx2.resize(C_nnz);
vals(C).resize(C_nnz);

C->resize(n_rows, n_cols);
int C_nnz = nnz + B->nnz;
C->idx2.resize(C_nnz);
C->vals.resize(C_nnz);
// Copies the value entries in [beg, end) to `out`.
//  - BSR: each source entry is dereferenced as a pointer to a block of
//    `b_size` doubles. A fresh block is allocated per entry, the source
//    block's contents are copied INTO it, and the new pointer is stored in
//    the destination, so the destination never aliases the source's blocks.
//    NOTE(review): the new blocks are raw `new[]` allocations; presumably the
//    destination matrix releases them on destruction -- confirm.
//  - CSR: entries are plain values and are copied directly.
// Fix: the previous std::copy call had source and destination swapped, which
// overwrote the source block with uninitialized memory and left the newly
// allocated block unfilled.
auto copy_vals = [b_size = A.b_size](auto beg, auto end, auto out) {
    (void) b_size; // unused in the CSR instantiation
    if constexpr (is_bsr_v<T>) {
        for (; beg != end; ++beg, ++out) {
            auto val = new double[b_size];
            std::copy(*beg, *beg + b_size, val);
            *out = val;
        }
    } else std::copy(beg, end, out);
};

C_nnz = 0;
C->idx1[0] = 0;
for (int i = 0; i < n_rows; i++)
C.idx1[0] = 0;
for (int i = 0; i < A.n_rows; i++)
{
start = idx1[i];
end = idx1[i+1];
std::copy(idx2.begin() + start,
idx2.begin() + end,
C->idx2.begin() + C_nnz);
std::copy(vals.begin() + start,
vals.begin() + end,
C->vals.begin() + C_nnz);
C_nnz += (end - start);
// Appends row `i` of `src` to `dst`, writing at the current offset `C_nnz`
// (captured by reference, so it reflects the value at call time): column
// indices are copied with std::copy and values with copy_vals. Returns the
// number of entries in the row so the caller can advance C_nnz.
auto add_row = [&](const T & src, T & dst) {
auto start = src.idx1[i];
auto end = src.idx1[i+1];
std::copy(src.idx2.begin() + start,
src.idx2.begin() + end,
dst.idx2.begin() + C_nnz);
copy_vals(vals(src).begin() + start,
vals(src).begin() + end,
vals(dst).begin() + C_nnz);
return end - start;
};

start = B->idx1[i];
end = B->idx1[i+1];
std::copy(B->idx2.begin() + start,
B->idx2.begin() + end,
C->idx2.begin() + C_nnz);
std::copy(B->vals.begin() + start,
B->vals.begin() + end,
C->vals.begin() + C_nnz);
C_nnz += (end - start);

C->idx1[i+1] = C_nnz;
C_nnz += add_row(A, C);
C_nnz += add_row(B, C);

C.idx1[i+1] = C_nnz;
}
C->nnz = C_nnz;
C->sort();
if (remove_dup)
C->remove_duplicates();
C.nnz = C_nnz;
C.sort();
if (remove_dup)
C.remove_duplicates();
}
}

// Combines the rows of this matrix and *B into *C by forwarding to the shared
// impl::add_append template; `remove_dup` is passed through unchanged.
// B and C are dereferenced unconditionally and must be non-null.
void CSRMatrix::add_append(CSRMatrix* B, CSRMatrix* C, bool remove_dup)
{
    CSRMatrix & lhs = *this;
    CSRMatrix & rhs = *B;
    CSRMatrix & out = *C;
    impl::add_append(lhs, rhs, out, remove_dup);
}

// Block-sparse counterpart of CSRMatrix::add_append: forwards this matrix,
// *B and *C to the shared impl::add_append template, which takes the BSR path
// because the arguments are BSRMatrix. `remove_dup` is passed through
// unchanged. B and C are dereferenced unconditionally and must be non-null.
void BSRMatrix::add_append(BSRMatrix * B, BSRMatrix * C, bool remove_dup)
{
    BSRMatrix & lhs = *this;
    BSRMatrix & rhs = *B;
    BSRMatrix & out = *C;
    impl::add_append(lhs, rhs, out, remove_dup);
}

CSRMatrix* CSRMatrix::subtract(CSRMatrix* B)
Expand Down
Loading

0 comments on commit 6f963f6

Please sign in to comment.