From e89bcfc0e86cd4952c03fdf920d11c598ae6e16a Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Tue, 3 Sep 2024 10:03:08 -0700 Subject: [PATCH] [SandboxIR] Add tracking for ShuffleVectorInst::commute. (#106644) Track it as an operand swap + a `setShuffleMask` and delegate to the `llvm::ShuffleVectorInst` implementation. --- llvm/include/llvm/SandboxIR/SandboxIR.h | 2 +- llvm/lib/SandboxIR/SandboxIR.cpp | 7 +++++++ llvm/unittests/SandboxIR/TrackerTest.cpp | 20 ++++++++++++++++---- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 2ed7243fa612f4..63c32da7cad20e 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -1143,7 +1143,7 @@ class ShuffleVectorInst final /// Swap the operands and adjust the mask to preserve the semantics of the /// instruction. - void commute() { cast(Val)->commute(); } + void commute(); /// Return true if a shufflevector instruction can be formed with the /// specified operands. diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 6bdc580f751d18..f95ced880d4cb2 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -2185,6 +2185,13 @@ VectorType *ShuffleVectorInst::getType() const { Ctx.getType(cast(Val)->getType())); } +void ShuffleVectorInst::commute() { + Ctx.getTracker().emplaceIfTracking(this); + Ctx.getTracker().emplaceIfTracking(getOperandUse(0), + getOperandUse(1)); + cast(Val)->commute(); +} + Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const { return Ctx.getOrCreateConstant( cast(Val)->getShuffleMaskForBitcode()); diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index c189100fbd6947..fe29452a8aea20 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -964,7 +964,7 @@ define void @foo(i32 %cond0, i32 %cond1) { EXPECT_EQ(Switch->findCaseDest(BB1), One); } -TEST_F(TrackerTest, ShuffleVectorInstSetters) { +TEST_F(TrackerTest, ShuffleVectorInst) { parseIR(C, R"IR( define void @foo(<2 x i8> %v1, <2 x i8> %v2) { %shuf = shufflevector <2 x i8> %v1, <2 x i8> %v2, <2 x i32> @@ -983,10 +983,22 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) { SmallVector OrigMask(SVI->getShuffleMask()); Ctx.save(); SVI->setShuffleMask(ArrayRef({0, 0})); - EXPECT_THAT(SVI->getShuffleMask(), - testing::Not(testing::ElementsAreArray(OrigMask))); + EXPECT_NE(SVI->getShuffleMask(), ArrayRef(OrigMask)); Ctx.revert(); - EXPECT_THAT(SVI->getShuffleMask(), testing::ElementsAreArray(OrigMask)); + EXPECT_EQ(SVI->getShuffleMask(), ArrayRef(OrigMask)); + + // Check commute. + auto *Op0 = SVI->getOperand(0); + auto *Op1 = SVI->getOperand(1); + Ctx.save(); + SVI->commute(); + EXPECT_EQ(SVI->getOperand(0), Op1); + EXPECT_EQ(SVI->getOperand(1), Op0); + EXPECT_NE(SVI->getShuffleMask(), ArrayRef(OrigMask)); + Ctx.revert(); + EXPECT_EQ(SVI->getOperand(0), Op0); + EXPECT_EQ(SVI->getOperand(1), Op1); + EXPECT_EQ(SVI->getShuffleMask(), ArrayRef(OrigMask)); } TEST_F(TrackerTest, PossiblyDisjointInstSetters) {