Skip to content

Commit

Permalink
[SandboxIR] Implement FixedVectorType (#107930)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sterling-Augustine authored Sep 10, 2024
1 parent d14a600 commit bb72865
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
46 changes: 46 additions & 0 deletions llvm/include/llvm/SandboxIR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Context;
// Forward declare friend classes for MSVC.
class PointerType;
class VectorType;
class FixedVectorType;
class IntegerType;
class FunctionType;
class ArrayType;
Expand All @@ -41,6 +42,7 @@ class Type {
friend class ArrayType; // For LLVMTy.
friend class StructType; // For LLVMTy.
friend class VectorType; // For LLVMTy.
friend class FixedVectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
Expand Down Expand Up @@ -344,6 +346,50 @@ class VectorType : public Type {
}
};

class FixedVectorType : public VectorType {
public:
static FixedVectorType *get(Type *ElementType, unsigned NumElts);

static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
return get(ElementType, FVTy->getNumElements());
}

static FixedVectorType *getInteger(FixedVectorType *VTy) {
return cast<FixedVectorType>(VectorType::getInteger(VTy));
}

static FixedVectorType *getExtendedElementVectorType(FixedVectorType *VTy) {
return cast<FixedVectorType>(VectorType::getExtendedElementVectorType(VTy));
}

static FixedVectorType *getTruncatedElementVectorType(FixedVectorType *VTy) {
return cast<FixedVectorType>(
VectorType::getTruncatedElementVectorType(VTy));
}

static FixedVectorType *getSubdividedVectorType(FixedVectorType *VTy,
int NumSubdivs) {
return cast<FixedVectorType>(
VectorType::getSubdividedVectorType(VTy, NumSubdivs));
}

static FixedVectorType *getHalfElementsVectorType(FixedVectorType *VTy) {
return cast<FixedVectorType>(VectorType::getHalfElementsVectorType(VTy));
}

static FixedVectorType *getDoubleElementsVectorType(FixedVectorType *VTy) {
return cast<FixedVectorType>(VectorType::getDoubleElementsVectorType(VTy));
}

static bool classof(const Type *T) {
return isa<llvm::FixedVectorType>(T->LLVMTy);
}

unsigned getNumElements() const {
return cast<llvm::FixedVectorType>(LLVMTy)->getNumElements();
}
};

class FunctionType : public Type {
public:
// TODO: add missing functions
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/SandboxIR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ bool VectorType::isValidElementType(Type *ElemTy) {
return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
}

FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
return cast<FixedVectorType>(ElementType->getContext().getType(
llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
}

IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
Expand Down
58 changes: 58 additions & 0 deletions llvm/unittests/SandboxIR/TypesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,64 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
}

TEST_F(SandboxTypeTest, FixedVectorType) {
parseIR(C, R"IR(
define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
// Check classof(), creation, accessors
auto *Vec4i16Ty = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec4i16Ty->getElementCount(), ElementCount::getFixed(4));

// get(ElementType, NumElements)
EXPECT_EQ(
sandboxir::FixedVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
F->getArg(0)->getType());
// get(ElementType, Other)
EXPECT_EQ(sandboxir::FixedVectorType::get(
sandboxir::Type::getInt16Ty(Ctx),
cast<sandboxir::FixedVectorType>(F->getArg(0)->getType())),
F->getArg(0)->getType());
auto *Vec4FTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
// getInteger
auto *Vec4i32Ty = sandboxir::FixedVectorType::getInteger(Vec4FTy);
EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
EXPECT_EQ(Vec4i32Ty->getElementCount(), Vec4FTy->getElementCount());
// getExtendedElementCountVectorType
auto *Vec4i64Ty =
sandboxir::FixedVectorType::getExtendedElementVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(32));
EXPECT_EQ(Vec4i64Ty->getElementCount(), Vec4i16Ty->getElementCount());
// getTruncatedElementVectorType
auto *Vec4i8Ty =
sandboxir::FixedVectorType::getTruncatedElementVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
EXPECT_EQ(Vec4i8Ty->getElementCount(), Vec4i8Ty->getElementCount());
// getSubdividedVectorType
auto *Vec8i8Ty =
sandboxir::FixedVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getFixed(8));
// getNumElements
EXPECT_EQ(Vec8i8Ty->getNumElements(), 8u);
// getHalfElementsVectorType
auto *Vec2i16Ty =
sandboxir::FixedVectorType::getHalfElementsVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec2i16Ty->getElementCount(), ElementCount::getFixed(2));
// getDoubleElementsVectorType
auto *Vec8i16Ty =
sandboxir::FixedVectorType::getDoubleElementsVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
}

TEST_F(SandboxTypeTest, FunctionType) {
parseIR(C, R"IR(
define void @foo() {
Expand Down

0 comments on commit bb72865

Please sign in to comment.