diff --git a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp
index dfcd1dab..e8fa2440 100644
--- a/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp
+++ b/lib/gc/Transforms/GPU/GpuToGpuOcl.cpp
@@ -415,10 +415,36 @@ struct ConvertLaunch final : ConvertOpPattern {
     blockSize.emplace_back(gpuLaunch.getBlockSizeY());
     blockSize.emplace_back(gpuLaunch.getBlockSizeZ());
 
-    for (auto arg : adaptor.getKernelOperands()) {
+    // Record the byte size of each kernel argument, derived from the
+    // operand types of gpuLaunch.
+    for (auto arg : gpuLaunch.getKernelOperands()) {
       auto type = arg.getType();
-      // Assuming, that the value is either an integer or a float or a pointer.
-      // In the latter case, the size is 0 bytes.
-      auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0;
+      size_t size;
+      // NOTE(review): the cast target types in this hunk were missing from the
+      // patch text; MemRefType and VectorType are presumed - confirm.
+      if (isa<MemRefType>(type)) {
+        size = 0; // A special case for pointers
+      } else if (type.isIndex()) {
+        size = helper.idxType.getIntOrFloatBitWidth() / 8;
+      } else if (type.isIntOrFloat()) {
+        size = type.getIntOrFloatBitWidth() / 8;
+      } else if (auto vectorType = dyn_cast<VectorType>(type)) {
+        type = vectorType.getElementType();
+        if (type.isIntOrFloat()) {
+          size = type.getIntOrFloatBitWidth();
+        } else if (type.isIndex()) {
+          size = helper.idxType.getIntOrFloatBitWidth();
+        } else {
+          llvm::errs() << "Unsupported vector element type: " << type << "\n";
+          return false;
+        }
+        // size holds the element width in bits here. Multiply by the element
+        // count before converting to bytes: dividing the count by 8 first
+        // truncates (e.g. 4 / 8 == 0) and yields a wrong, often zero, size.
+        size = size * vectorType.getNumElements() / 8;
+      } else {
+        llvm::errs() << "Unsupported type: " << type << "\n";
+        return false;
+      }
       argSize.emplace_back(helper.idxConstant(rewriter, loc, size));
     }