Skip to content

Commit

Permalink
Added support for vector types
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Sep 20, 2024
1 parent 34415b1 commit e3ceb4f
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions lib/gc/Transforms/GPU/GpuToGpuOcl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,30 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
blockSize.emplace_back(gpuLaunch.getBlockSizeY());
blockSize.emplace_back(gpuLaunch.getBlockSizeZ());

for (auto arg : adaptor.getKernelOperands()) {
for (auto arg : gpuLaunch.getKernelOperands()) {
auto type = arg.getType();
// Assuming, that the value is either an integer or a float or a pointer.
// In the latter case, the size is 0 bytes.
auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0;
size_t size;
if (isa<MemRefType>(type)) {
size = 0; // A special case for pointers
} else if (type.isIndex()) {
size = helper.idxType.getIntOrFloatBitWidth() / 8;
} else if (type.isIntOrFloat()) {
size = type.getIntOrFloatBitWidth() / 8;
} else if (auto vectorType = dyn_cast<VectorType>(type)) {
type = vectorType.getElementType();
if (type.isIntOrFloat()) {
size = type.getIntOrFloatBitWidth();
} else if (type.isIndex()) {
size = helper.idxType.getIntOrFloatBitWidth();
} else {
llvm::errs() << "Unsupported vector element type: " << type << "\n";
return false;
}
size *= vectorType.getNumElements() / 8;
} else {
llvm::errs() << "Unsupported type: " << type << "\n";
return false;
}
argSize.emplace_back(helper.idxConstant(rewriter, loc, size));
}

Expand Down

0 comments on commit e3ceb4f

Please sign in to comment.