Skip to content

Commit

Permalink
Fix bcast function for dense format with magma
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanlucf22 authored and nicolasbock committed Feb 4, 2022
1 parent 38dd02c commit 3529cda
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
25 changes: 18 additions & 7 deletions src/C-interface/dense/bml_parallel_dense.c
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,24 @@ bml_mpi_bcast_matrix_dense(
const int root,
MPI_Comm comm)
{
// create MPI data type to avoid multiple messages
MPI_Datatype mpi_data_type;
bml_mpi_type_create_struct_dense(A, &mpi_data_type);

MPI_Bcast(A, 1, mpi_data_type, root, comm);

MPI_Type_free(&mpi_data_type);
switch (A->matrix_precision)
{
case single_real:
return bml_mpi_bcast_matrix_dense_single_real(A, root, comm);
break;
case double_real:
return bml_mpi_bcast_matrix_dense_double_real(A, root, comm);
break;
case single_complex:
return bml_mpi_bcast_matrix_dense_single_complex(A, root, comm);
break;
case double_complex:
return bml_mpi_bcast_matrix_dense_double_complex(A, root, comm);
break;
default:
LOG_ERROR("unknown precision\n");
break;
}
}

#endif
17 changes: 17 additions & 0 deletions src/C-interface/dense/bml_parallel_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ void bml_mpi_bcast_matrix_dense(
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm);

void bml_mpi_bcast_matrix_dense_single_real(
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm);
void bml_mpi_bcast_matrix_dense_double_real(
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm);
void bml_mpi_bcast_matrix_dense_single_complex(
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm);
void bml_mpi_bcast_matrix_dense_double_complex(
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm);
#endif

#endif
28 changes: 23 additions & 5 deletions src/C-interface/dense/bml_parallel_dense_typed.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
#include <omp.h>
#endif

#ifdef BML_USE_MAGMA
// buffer on CPU to be used for communications
static MAGMA_T *A_matrix_buffer;
#endif

/** Gather a bml matrix across MPI ranks.
*
* \ingroup parallel_group
Expand Down Expand Up @@ -164,4 +159,27 @@ bml_matrix_dense_t
return A_bml;
}

void TYPED_FUNC(
bml_mpi_bcast_matrix_dense) (
bml_matrix_dense_t * A,
const int root,
MPI_Comm comm)
{
#ifdef BML_USE_MAGMA
MAGMA_T *A_matrix = bml_allocate_memory(sizeof(MAGMA_T) * A->N * A->N);
MAGMA(getmatrix) (A->N, A->N, A->matrix, A->ld, A_matrix, A->N,
bml_queue());
#else
REAL_T *A_matrix = A->matrix;
#endif

MPI_Bcast(A_matrix, A->N * A->N, MPI_T, root, comm);

#ifdef BML_USE_MAGMA
MAGMA(setmatrix) (A->N, A->N, A_matrix, A->N, A->matrix, A->ld,
bml_queue());
bml_free_memory(A_matrix);
#endif
}

#endif

0 comments on commit 3529cda

Please sign in to comment.