From 103566846b0f211513c962c5c6df3874435685cb Mon Sep 17 00:00:00 2001
From: Romaric Jodin <89833130+rjodinchr@users.noreply.github.com>
Date: Fri, 31 May 2024 15:38:06 +0200
Subject: [PATCH] fix detection to run ThreeElementVectorLoweringPass (#1367)

* fix detection to run ThreeElementVectorLoweringPass

If a gep ends up trying to access the 4th element of a vec3, we need to
lower vec3 to vec4.

This is fixing a regression on CTS test_half

Ref https://github.com/kpet/clvk/pull/698

* do not consider gep with only 1 index
---
 lib/ThreeElementVectorLoweringPass.cpp        | 28 +++++++++++++++++--
 lib/ThreeElementVectorLoweringPass.h          |  9 ++++--
 .../invalid_vec3_gep.ll                       | 13 +++++++++
 .../invalid_vec3_gep2.ll                      | 13 +++++++++
 .../valid_vec3_gep.ll                         | 13 +++++++++
 5 files changed, 72 insertions(+), 4 deletions(-)
 create mode 100644 test/ThreeElementVectorLowering/invalid_vec3_gep.ll
 create mode 100644 test/ThreeElementVectorLowering/invalid_vec3_gep2.ll
 create mode 100644 test/ThreeElementVectorLowering/valid_vec3_gep.ll

diff --git a/lib/ThreeElementVectorLoweringPass.cpp b/lib/ThreeElementVectorLoweringPass.cpp
index 1038de487..91b18a8e6 100644
--- a/lib/ThreeElementVectorLoweringPass.cpp
+++ b/lib/ThreeElementVectorLoweringPass.cpp
@@ -269,22 +269,46 @@ bool clspv::ThreeElementVectorLoweringPass::vec3ShouldBeLowered(Module &M) {
     return false;
   default:
     for (auto &F : M.functions()) {
-      if (vec3BitcastInFunction(F))
+      if (vec3ShouldBeLowered(F))
         return true;
     }
     return false;
   }
 }
 
-bool clspv::ThreeElementVectorLoweringPass::vec3BitcastInFunction(Function &F) {
+bool clspv::ThreeElementVectorLoweringPass::vec3ShouldBeLowered(Function &F) {
   for (Instruction &I : instructions(F)) {
     if (haveImplicitCast(&I)) {
       return true;
+    } else if (haveInvalidVec3GEP(&I)) {
+      return true;
     }
   }
   return false;
 }
 
+bool clspv::ThreeElementVectorLoweringPass::haveInvalidVec3GEP(Value *Value) {
+  auto gep = dyn_cast<GetElementPtrInst>(Value);
+  if (!gep || gep->getNumIndices() <= 1) {
+    return false;
+  }
+
+  SmallVector<llvm::Value *, 4> idxs(gep->idx_begin(), gep->idx_end() - 1);
+  auto last_type =
+      GetElementPtrInst::getIndexedType(gep->getSourceElementType(), idxs);
+  auto vec_type = dyn_cast<FixedVectorType>(last_type);
+  if (!vec_type || vec_type->getNumElements() != 3) {
+    return false;
+  }
+  auto last_idx = gep->getOperand(gep->getNumOperands() - 1);
+  auto cst_idx = dyn_cast<ConstantInt>(last_idx);
+  if (!cst_idx || cst_idx->getZExtValue() >= 3) {
+    return true;
+  }
+
+  return false;
+}
+
 bool clspv::ThreeElementVectorLoweringPass::haveImplicitCast(Value *Value) {
   Type *source_ty = nullptr;
   Type *dest_ty = nullptr;
diff --git a/lib/ThreeElementVectorLoweringPass.h b/lib/ThreeElementVectorLoweringPass.h
index 61069ed38..8372569bc 100644
--- a/lib/ThreeElementVectorLoweringPass.h
+++ b/lib/ThreeElementVectorLoweringPass.h
@@ -104,8 +104,9 @@ struct ThreeElementVectorLoweringPass
 private:
   // High-level implementation details of runOnModule.
 
-  /// Look for bitcast of vec3 inside the function
-  bool vec3BitcastInFunction(llvm::Function &F);
+  /// Look for vec3 patterns inside the function that require lowering vec3 to
+  /// vec4.
+  bool vec3ShouldBeLowered(llvm::Function &F);
 
   /// Returns whether the vec3 should be transformed into vec4
   bool vec3ShouldBeLowered(llvm::Module &M);
@@ -114,6 +115,10 @@ struct ThreeElementVectorLoweringPass
   /// opaque pointers
   bool haveImplicitCast(llvm::Value *Value);
 
+  /// Returns true if the type before the last index is a vec3 and the last
+  /// index is not constant or is greater than or equal to 3.
+  bool haveInvalidVec3GEP(llvm::Value *Value);
+
   /// Lower all global variables in the module.
   bool runOnGlobals(llvm::Module &M);
 
diff --git a/test/ThreeElementVectorLowering/invalid_vec3_gep.ll b/test/ThreeElementVectorLowering/invalid_vec3_gep.ll
new file mode 100644
index 000000000..2602ef47e
--- /dev/null
+++ b/test/ThreeElementVectorLowering/invalid_vec3_gep.ll
@@ -0,0 +1,13 @@
+; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
+; RUN: FileCheck %s < %t.ll
+
+; CHECK: getelementptr inbounds <4 x i32>, ptr %a, i32 2, i32 3
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+define dso_local spir_kernel void @test1(ptr %a) {
+entry:
+  %gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 3
+  ret void
+}
diff --git a/test/ThreeElementVectorLowering/invalid_vec3_gep2.ll b/test/ThreeElementVectorLowering/invalid_vec3_gep2.ll
new file mode 100644
index 000000000..11aeeb732
--- /dev/null
+++ b/test/ThreeElementVectorLowering/invalid_vec3_gep2.ll
@@ -0,0 +1,13 @@
+; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
+; RUN: FileCheck %s < %t.ll
+
+; CHECK: getelementptr inbounds <4 x i32>, ptr %a, i32 2, i32 %i
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+define dso_local spir_kernel void @test1(ptr %a, i32 %i) {
+entry:
+  %gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 %i
+  ret void
+}
diff --git a/test/ThreeElementVectorLowering/valid_vec3_gep.ll b/test/ThreeElementVectorLowering/valid_vec3_gep.ll
new file mode 100644
index 000000000..995fac0ff
--- /dev/null
+++ b/test/ThreeElementVectorLowering/valid_vec3_gep.ll
@@ -0,0 +1,13 @@
+; RUN: clspv-opt %s -o %t.ll --passes=three-element-vector-lowering
+; RUN: FileCheck %s < %t.ll
+
+; CHECK: getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 2
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+define dso_local spir_kernel void @test1(ptr %a) {
+entry:
+  %gep = getelementptr inbounds <3 x i32>, ptr %a, i32 2, i32 2
+  ret void
+}