Skip to content

Commit

Permalink
Store argPtrs of cuda kernels in a std array instead of a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 authored and MihailMihov committed Oct 27, 2024
1 parent 781c08e commit 4dc99c9
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ CUDA_HOST_DEVICE T push(tape<T>& to, ArgsT... val) {
#if defined(__CUDACC__) && !defined(__CUDA_ARCH__)
if (CUDAkernel) {
constexpr size_t totalArgs = sizeof...(args) + sizeof...(Rest);
std::vector<void*> argPtrs;
argPtrs.reserve(totalArgs);
(argPtrs.push_back(static_cast<void*>(&args)), ...);
std::array<void*, totalArgs> argPtrs = {static_cast<void*>(&args)...,
static_cast<Rest>(nullptr)...};

void* null_param = nullptr;
for (size_t i = sizeof...(args); i < totalArgs; ++i)
Expand Down

0 comments on commit 4dc99c9

Please sign in to comment.