Skip to content

Commit

Permalink
[NVPTX] Implement variadic functions using IR lowering (llvm#96015)
Browse files Browse the repository at this point in the history
Summary:
This patch implements support for variadic functions for NVPTX targets.
The implementation here mainly follows what was done to implement it for
AMDGPU in llvm#93362.

We change the NVPTX codegen to lower all variadic arguments to functions
by-value. This creates a flattened set of arguments that the IR lowering
pass converts into a struct with the proper alignment.

The behavior of this function was determined by iteratively checking
what the NVCC compiler generates for its output. See examples like
https://godbolt.org/z/KavfTGY93. I have noted the main methods that
NVIDIA uses to lower variadic functions.

1. All arguments are passed in a pointer to aggregate.
2. The minimum alignment for a plain argument is 4 bytes.
3. Alignment is dictated by the underlying type.
4. Structs are flattened and do not have their alignment changed.
5. NVPTX never passes any arguments indirectly, even very large ones.

This patch passes the tests in the `libc` project currently, including
support for `sprintf`.
  • Loading branch information
jhuber6 authored and aaryanshukla committed Jul 16, 2024
1 parent 3cf142a commit 3517a11
Show file tree
Hide file tree
Showing 9 changed files with 930 additions and 32 deletions.
3 changes: 1 addition & 2 deletions clang/lib/Basic/Targets/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
}

BuiltinVaListKind getBuiltinVaListKind() const override {
// FIXME: implement
return TargetInfo::CharPtrBuiltinVaList;
return TargetInfo::VoidPtrBuiltinVaList;
}

bool isValidCPUName(StringRef Name) const override {
Expand Down
12 changes: 9 additions & 3 deletions clang/lib/CodeGen/Targets/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,11 @@ ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {
if (!getCXXABI().classifyReturnType(FI))
FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
for (auto &I : FI.arguments())
I.info = classifyArgumentType(I.type);

for (auto &&[ArgumentsCount, I] : llvm::enumerate(FI.arguments()))
I.info = ArgumentsCount < FI.getNumRequiredArgs()
? classifyArgumentType(I.type)
: ABIArgInfo::getDirect();

// Always honor user-specified calling convention.
if (FI.getCallingConvention() != llvm::CallingConv::C)
Expand All @@ -215,7 +218,10 @@ void NVPTXABIInfo::computeInfo(CGFunctionInfo &FI) const {

RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
QualType Ty, AggValueSlot Slot) const {
llvm_unreachable("NVPTX does not support varargs");
return emitVoidPtrVAArg(CGF, VAListAddr, Ty, /*IsIndirect=*/false,
getContext().getTypeInfoInChars(Ty),
CharUnits::fromQuantity(1),
/*AllowHigherAlign=*/true, Slot);
}

void NVPTXTargetCodeGenInfo::setTargetAttributes(
Expand Down
94 changes: 94 additions & 0 deletions clang/test/CodeGen/variadic-nvptx.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -emit-llvm -o - %s | FileCheck %s

extern void varargs_simple(int, ...);

// CHECK-LABEL: define dso_local void @foo(
// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[C:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[S:%.*]] = alloca i16, align 2
// CHECK-NEXT: [[I:%.*]] = alloca i32, align 4
// CHECK-NEXT: [[L:%.*]] = alloca i64, align 8
// CHECK-NEXT: [[F:%.*]] = alloca float, align 4
// CHECK-NEXT: [[D:%.*]] = alloca double, align 8
// CHECK-NEXT: [[A:%.*]] = alloca [[STRUCT_ANON:%.*]], align 4
// CHECK-NEXT: [[V:%.*]] = alloca <4 x i32>, align 16
// CHECK-NEXT: [[T:%.*]] = alloca [[STRUCT_ANON_0:%.*]], align 1
// CHECK-NEXT: store i8 1, ptr [[C]], align 1
// CHECK-NEXT: store i16 1, ptr [[S]], align 2
// CHECK-NEXT: store i32 1, ptr [[I]], align 4
// CHECK-NEXT: store i64 1, ptr [[L]], align 8
// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 4
// CHECK-NEXT: store double 1.000000e+00, ptr [[D]], align 8
// CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[C]], align 1
// CHECK-NEXT: [[CONV:%.*]] = sext i8 [[TMP0]] to i32
// CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[S]], align 2
// CHECK-NEXT: [[CONV1:%.*]] = sext i16 [[TMP1]] to i32
// CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[I]], align 4
// CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr [[L]], align 8
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[F]], align 4
// CHECK-NEXT: [[CONV2:%.*]] = fpext float [[TMP4]] to double
// CHECK-NEXT: [[TMP5:%.*]] = load double, ptr [[D]], align 8
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 noundef [[CONV]], i32 noundef [[CONV1]], i32 noundef [[TMP2]], i64 noundef [[TMP3]], double noundef [[CONV2]], double noundef [[TMP5]])
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[A]], ptr align 4 @__const.foo.a, i64 12, i1 false)
// CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 0
// CHECK-NEXT: [[TMP7:%.*]] = load i32, ptr [[TMP6]], align 4
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 1
// CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 4
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT_ANON]], ptr [[A]], i32 0, i32 2
// CHECK-NEXT: [[TMP11:%.*]] = load i32, ptr [[TMP10]], align 4
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i32 [[TMP7]], i8 [[TMP9]], i32 [[TMP11]])
// CHECK-NEXT: store <4 x i32> <i32 1, i32 1, i32 1, i32 1>, ptr [[V]], align 16
// CHECK-NEXT: [[TMP12:%.*]] = load <4 x i32>, ptr [[V]], align 16
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, <4 x i32> noundef [[TMP12]])
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP14:%.*]] = load i8, ptr [[TMP13]], align 1
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP16:%.*]] = load i8, ptr [[TMP15]], align 1
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP18:%.*]] = load i8, ptr [[TMP17]], align 1
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP20:%.*]] = load i8, ptr [[TMP19]], align 1
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 0
// CHECK-NEXT: [[TMP22:%.*]] = load i8, ptr [[TMP21]], align 1
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [[STRUCT_ANON_0]], ptr [[T]], i32 0, i32 1
// CHECK-NEXT: [[TMP24:%.*]] = load i8, ptr [[TMP23]], align 1
// CHECK-NEXT: call void (i32, ...) @varargs_simple(i32 noundef 0, i8 [[TMP14]], i8 [[TMP16]], i8 [[TMP18]], i8 [[TMP20]], i32 noundef 0, i8 [[TMP22]], i8 [[TMP24]])
// CHECK-NEXT: ret void
//
// Exercises variadic calls with a mix of scalars, a flattened struct, a
// vector, and repeated small structs. The autogenerated CHECK lines above
// pin the exact IR this lowering produces, so the code here must not change.
void foo() {
char c = '\x1';
short s = 1;
int i = 1;
long l = 1;
float f = 1.f;
double d = 1.;
// Default argument promotions apply at the call: char/short are sign-extended
// to i32 and float is extended to double (see the sext/fpext CHECK lines).
varargs_simple(0, c, s, i, l, f, d);

// Struct argument: passed by flattening into its individual fields
// (three separate loads in the IR, not a byval pointer).
struct {int x; char c; int y;} a = {1, '\x1', 1};
varargs_simple(0, a);

// Vector argument: passed directly as <4 x i32>.
typedef int __attribute__((ext_vector_type(4))) int4;
int4 v = {1, 1, 1, 1};
varargs_simple(0, v);

// Repeated small struct interleaved with a scalar; each occurrence is
// flattened into its i8 fields independently.
struct {char c, d;} t;
varargs_simple(0, t, t, 0, t);
}

typedef struct {long x; long y;} S;
extern void varargs_complex(S, S, ...);

// CHECK-LABEL: define dso_local void @bar(
// CHECK-SAME: ) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 8
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[S]], ptr align 8 @__const.bar.s, i64 16, i1 false)
// CHECK-NEXT: call void (ptr, ptr, ...) @varargs_complex(ptr noundef byval([[STRUCT_S]]) align 8 [[S]], ptr noundef byval([[STRUCT_S]]) align 8 [[S]], i32 noundef 1, i64 noundef 1, double noundef 1.000000e+00)
// CHECK-NEXT: ret void
//
// Variadic call whose fixed (required) arguments are structs: the CHECK
// lines above show they stay byval pointers, while only the trailing
// variadic arguments (int, long, double) go through the new lowering.
void bar() {
S s = {1l, 1l};
varargs_complex(s, s, 1, 1l, 1.0);
}
15 changes: 4 additions & 11 deletions libc/config/gpu/entrypoints.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
if(LIBC_TARGET_ARCHITECTURE_IS_AMDGPU)
set(extra_entrypoints
# stdio.h entrypoints
libc.src.stdio.snprintf
libc.src.stdio.sprintf
libc.src.stdio.vsnprintf
libc.src.stdio.vsprintf
)
endif()

set(TARGET_LIBC_ENTRYPOINTS
# assert.h entrypoints
libc.src.assert.__assert_fail
Expand Down Expand Up @@ -185,9 +175,12 @@ set(TARGET_LIBC_ENTRYPOINTS
libc.src.errno.errno

# stdio.h entrypoints
${extra_entrypoints}
libc.src.stdio.clearerr
libc.src.stdio.fclose
libc.src.stdio.sprintf
libc.src.stdio.snprintf
libc.src.stdio.vsprintf
libc.src.stdio.vsnprintf
libc.src.stdio.feof
libc.src.stdio.ferror
libc.src.stdio.fflush
Expand Down
21 changes: 9 additions & 12 deletions libc/test/src/__support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,15 @@ add_libc_test(
libc.src.__support.uint128
)

# NVPTX does not support varargs currently.
if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
add_libc_test(
arg_list_test
SUITE
libc-support-tests
SRCS
arg_list_test.cpp
DEPENDS
libc.src.__support.arg_list
)
endif()
# Previously guarded out on NVPTX; runs on all targets now that NVPTX
# supports variadic functions.
add_libc_test(
arg_list_test
SUITE
libc-support-tests
SRCS
arg_list_test.cpp
DEPENDS
libc.src.__support.arg_list
)

if(NOT LIBC_TARGET_ARCHITECTURE_IS_NVPTX)
add_libc_test(
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO/ExpandVariadics.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
Expand Down Expand Up @@ -342,6 +343,7 @@ void NVPTXPassConfig::addIRPasses() {
}

addPass(createAtomicExpandLegacyPass());
addPass(createExpandVariadicsPass(ExpandVariadicsMode::Lowering));
addPass(createNVPTXCtorDtorLoweringLegacyPass());

// === LSR and other generic IR passes ===
Expand Down
40 changes: 36 additions & 4 deletions llvm/lib/Transforms/IPO/ExpandVariadics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder,
// Replace known calls to the variadic with calls to the va_list equivalent
for (User *U : make_early_inc_range(VariadicWrapper->users())) {
if (CallBase *CB = dyn_cast<CallBase>(U)) {
Value *calledOperand = CB->getCalledOperand();
if (VariadicWrapper == calledOperand)
Value *CalledOperand = CB->getCalledOperand();
if (VariadicWrapper == CalledOperand)
Changed |=
expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(),
FixedArityReplacement);
Expand Down Expand Up @@ -938,6 +938,33 @@ struct Amdgpu final : public VariadicABIInfo {
}
};

// ABI description consulted by the ExpandVariadics pass when lowering
// variadic functions for NVPTX: the va_list is an opaque pointer into the
// argument buffer, carried as a plain SSA value.
struct NVPTX final : public VariadicABIInfo {

// Variadic lowering is unconditionally enabled for NVPTX.
bool enableForTarget() override { return true; }

// The va_list pointer is threaded through as an SSA value rather than
// spilled to an alloca.
bool vaListPassedInSSARegister() override { return true; }

// va_list is an opaque pointer (default address space).
Type *vaListType(LLVMContext &Ctx) override {
return PointerType::getUnqual(Ctx);
}

// The same opaque pointer type is used when passing a va_list as a
// function parameter.
Type *vaListParameterType(Module &M) override {
return PointerType::getUnqual(M.getContext());
}

// Initialization is just an address-space cast of the argument buffer to
// the va_list parameter type; no per-field setup is required.
Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder,
AllocaInst *, Value *Buffer) override {
return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M));
}

VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override {
// NVPTX expects natural alignment in all cases. The variadic call ABI will
// handle promoting types to their appropriate size and alignment.
Align A = DL.getABITypeAlign(Parameter);
// 'false' = the argument is stored in the slot directly, never indirectly
// via a pointer (matches note 5 in the commit message).
return {A, false};
}
};

struct Wasm final : public VariadicABIInfo {

bool enableForTarget() override {
Expand Down Expand Up @@ -967,8 +994,8 @@ struct Wasm final : public VariadicABIInfo {
if (A < MinAlign)
A = Align(MinAlign);

if (auto s = dyn_cast<StructType>(Parameter)) {
if (s->getNumElements() > 1) {
if (auto *S = dyn_cast<StructType>(Parameter)) {
if (S->getNumElements() > 1) {
return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true};
}
}
Expand All @@ -988,6 +1015,11 @@ std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) {
return std::make_unique<Wasm>();
}

case Triple::nvptx:
case Triple::nvptx64: {
return std::make_unique<NVPTX>();
}

default:
return {};
}
Expand Down
Loading

0 comments on commit 3517a11

Please sign in to comment.