Skip to content

Commit

Permalink
fix detection to run ThreeElementVectorLoweringPass (#1367)
Browse files Browse the repository at this point in the history
* fix detection to run ThreeElementVectorLoweringPass

If a gep ends up trying to access the 4th element of a vec3, we need
to lower vec3 to vec4.

This is fixing a regression on CTS test_half
Ref kpet/clvk#698

* do not consider gep with only 1 indice
  • Loading branch information
rjodinchr authored May 31, 2024
1 parent 86dc06b commit 1035668
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 4 deletions.
28 changes: 26 additions & 2 deletions lib/ThreeElementVectorLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,22 +269,46 @@ bool clspv::ThreeElementVectorLoweringPass::vec3ShouldBeLowered(Module &M) {
return false;
default:
for (auto &F : M.functions()) {
if (vec3BitcastInFunction(F))
if (vec3ShouldBeLowered(F))
return true;
}
return false;
}
}

bool clspv::ThreeElementVectorLoweringPass::vec3BitcastInFunction(Function &F) {
bool clspv::ThreeElementVectorLoweringPass::vec3ShouldBeLowered(Function &F) {
for (Instruction &I : instructions(F)) {
if (haveImplicitCast(&I)) {
return true;
} else if (haveInvalidVec3GEP(&I)) {
return true;
}
}
return false;
}

bool clspv::ThreeElementVectorLoweringPass::haveInvalidVec3GEP(Value *Value) {
auto gep = dyn_cast<GetElementPtrInst>(Value);
if (!gep || gep->getNumIndices() <= 1) {
return false;
}

SmallVector<llvm::Value *> idxs(gep->idx_begin(), gep->idx_end() - 1);
auto last_type =
GetElementPtrInst::getIndexedType(gep->getSourceElementType(), idxs);
auto vec_type = dyn_cast<FixedVectorType>(last_type);
if (!vec_type || vec_type->getNumElements() != 3) {
return false;
}
auto last_idx = gep->getOperand(gep->getNumOperands() - 1);
auto cst_idx = dyn_cast<ConstantInt>(last_idx);
if (!cst_idx || cst_idx->getZExtValue() >= 3) {
return true;
}

return false;
}

bool clspv::ThreeElementVectorLoweringPass::haveImplicitCast(Value *Value) {
Type *source_ty = nullptr;
Type *dest_ty = nullptr;
Expand Down
9 changes: 7 additions & 2 deletions lib/ThreeElementVectorLoweringPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ struct ThreeElementVectorLoweringPass
private:
// High-level implementation details of runOnModule.

/// Look for bitcast of vec3 inside the function
bool vec3BitcastInFunction(llvm::Function &F);
/// Look for vec3 patterns inside the function that requires lowering vec3 to
/// vec4.
bool vec3ShouldBeLowered(llvm::Function &F);

/// Returns whether the vec3 should be transformed into vec4
bool vec3ShouldBeLowered(llvm::Module &M);
Expand All @@ -114,6 +115,10 @@ struct ThreeElementVectorLoweringPass
/// opaque pointers
bool haveImplicitCast(llvm::Value *Value);

// Returns true if the type before last indice is a vec3 and last indice is
// not constant or bigger or equal to 3.
bool haveInvalidVec3GEP(llvm::Value *Value);

/// Lower all global variables in the module.
bool runOnGlobals(llvm::Module &M);

Expand Down
13 changes: 13 additions & 0 deletions test/ThreeElementVectorLowering/invalid_vec3_gep.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
; RUN: FileCheck %s < %t.ll

; CHECK: getelementptr inbounds <4 x i32>, ptr %a, i32 2, i32 3

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define dso_local spir_kernel void @test1(ptr %a) {
entry:
%gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 3
ret void
}
13 changes: 13 additions & 0 deletions test/ThreeElementVectorLowering/invalid_vec3_gep2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
; RUN: FileCheck %s < %t.ll

; CHECK: getelementptr inbounds <4 x i32>, ptr %a, i32 2, i32 %i

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define dso_local spir_kernel void @test1(ptr %a, i32 %i) {
entry:
%gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 %i
ret void
}
13 changes: 13 additions & 0 deletions test/ThreeElementVectorLowering/valid_vec3_gep.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
; RUN: FileCheck %s < %t.ll

; CHECK: getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 2

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

define dso_local spir_kernel void @test1(ptr %a) {
entry:
%gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 2
ret void
}

0 comments on commit 1035668

Please sign in to comment.