From 91db28b64bb1dfb76c5e621f3aaabe783b285f77 Mon Sep 17 00:00:00 2001 From: Zach Atkins Date: Fri, 22 Dec 2023 12:09:32 -0700 Subject: [PATCH] GPU - Don't call diagonal kernels if input arrays are null pointers --- backends/cuda-ref/ceed-cuda-ref-operator.c | 39 ++++++++++---------- backends/hip-ref/ceed-hip-ref-operator.c | 41 ++++++++++++---------- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/backends/cuda-ref/ceed-cuda-ref-operator.c b/backends/cuda-ref/ceed-cuda-ref-operator.c index f47c52b7a4..3e6f6b6ea3 100644 --- a/backends/cuda-ref/ceed-cuda-ref-operator.c +++ b/backends/cuda-ref/ceed-cuda-ref-operator.c @@ -809,25 +809,28 @@ static inline int CeedOperatorAssembleDiagonalCore_Cuda(CeedOperator op, CeedVec } CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); - // Assemble element operator diagonals - CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); - CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); - CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); - - // Compute the diagonal of B^T D B - int elem_per_block = 1; - int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0); - void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out, - &diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array}; - if (is_point_block) { - CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearPointBlock, grid, diag->num_nodes, 1, elem_per_block, args)); - } else { - CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearDiagonal, grid, diag->num_nodes, 1, elem_per_block, args)); - } + // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers + if (diag->num_nodes > 0) { + // Assemble element operator diagonals + CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); + CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); + CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); + + // Compute the diagonal of B^T D B + int elem_per_block = 1; + int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0); + void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out, + &diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array}; + if (is_point_block) { + CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearPointBlock, grid, diag->num_nodes, 1, elem_per_block, args)); + } else { + CeedCallBackend(CeedRunKernelDim_Cuda(ceed, diag->linearDiagonal, grid, diag->num_nodes, 1, elem_per_block, args)); + } - // Restore arrays - CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); - CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); + // Restore arrays + CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); + CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); + } // Assemble local operator diagonal CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request)); diff --git a/backends/hip-ref/ceed-hip-ref-operator.c b/backends/hip-ref/ceed-hip-ref-operator.c index 486269bb9d..6205109df3 100644 --- a/backends/hip-ref/ceed-hip-ref-operator.c +++ b/backends/hip-ref/ceed-hip-ref-operator.c @@ -812,26 +812,29 @@ static inline int CeedOperatorAssembleDiagonalCore_Hip(CeedOperator op, CeedVect } CeedCallBackend(CeedVectorSetValue(elem_diag, 0.0)); - // Assemble element operator diagonals - CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); - CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); - CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); - - // Compute the diagonal of B^T D B - int elem_per_block = 1; - int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0); - void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out, - &diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array}; - - if (is_point_block) { - CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearPointBlock, grid, diag->num_modes, 1, elem_per_block, args)); - } else { - CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearDiagonal, grid, diag->num_modes, 1, elem_per_block, args)); - } + // Only assemble diagonal if the basis has nodes, otherwise inputs are null pointers + if (diag->num_modes > 0) { + // Assemble element operator diagonals + CeedCallBackend(CeedVectorGetArray(elem_diag, CEED_MEM_DEVICE, &elem_diag_array)); + CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array)); + CeedCallBackend(CeedElemRestrictionGetNumElements(diag_rstr, &num_elem)); + + // Compute the diagonal of B^T D B + int elem_per_block = 1; + int grid = num_elem / elem_per_block + ((num_elem / elem_per_block * elem_per_block < num_elem) ? 1 : 0); + void *args[] = {(void *)&num_elem, &diag->d_identity, &diag->d_interp_in, &diag->d_grad_in, &diag->d_interp_out, + &diag->d_grad_out, &diag->d_e_mode_in, &diag->d_e_mode_out, &assembled_qf_array, &elem_diag_array}; + + if (is_point_block) { + CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearPointBlock, grid, diag->num_modes, 1, elem_per_block, args)); + } else { + CeedCallBackend(CeedRunKernelDim_Hip(ceed, diag->linearDiagonal, grid, diag->num_modes, 1, elem_per_block, args)); + } - // Restore arrays - CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); - CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); + // Restore arrays + CeedCallBackend(CeedVectorRestoreArray(elem_diag, &elem_diag_array)); + CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array)); + } // Assemble local operator diagonal CeedCallBackend(CeedElemRestrictionApply(diag_rstr, CEED_TRANSPOSE, elem_diag, assembled, request));