Skip to content

Commit

Permalink
[flang] finish BIND(C) VALUE derived type passing ABI on X86-64 (llvm…
Browse files Browse the repository at this point in the history
…#77742)

Derived types passed with VALUE in a BIND(C) context must be passed like C
structs, and LLVM does not implement the ABI for this (it is up to the
frontends, like clang).

The previous patch llvm#75802 implemented the simple cases where the derived
type has a single field; this patch implements the general case. Note that
the generated LLVM IR is compliant from an X86-64 C ABI point of view and
compatible with clang-generated assembly, but it is not guaranteed
to match the LLVM IR signatures generated by clang for the equivalent C
functions, because several LLVM IR signatures may lead to the same X86-64
signature.
  • Loading branch information
jeanPerier authored Jan 12, 2024
1 parent c65b939 commit 011ba72
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 63 deletions.
56 changes: 53 additions & 3 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,36 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
return {};
}

/// Select the scalar MLIR type used to pass one part of an aggregate
/// argument in a register.
///
/// \param loc location used when reporting the not-yet-implemented case.
/// \param context MLIR context in which the returned type is created.
/// \param argClass class of the part; ArgClass::SSE selects a floating point
///        type, any other class selects an integer type.
/// \param partByteSize size in bytes of the part to be passed.
/// \return the smallest of f16/f32/f64/f128 (SSE) or i8/i16/i32/i64
///         (otherwise) that is at least \p partByteSize bytes wide.
mlir::Type pickLLVMArgType(mlir::Location loc, mlir::MLIRContext *context,
                           ArgClass argClass,
                           std::uint64_t partByteSize) const {
  if (argClass != ArgClass::SSE) {
    // Integer part: round the byte size up to the next supported bit width.
    assert(partByteSize <= 8 &&
           "expect integer part of aggregate argument to fit into eight bytes");
    unsigned bitWidth = 8;
    if (partByteSize > 4)
      bitWidth = 64;
    else if (partByteSize > 2)
      bitWidth = 32;
    else if (partByteSize > 1)
      bitWidth = 16;
    return mlir::IntegerType::get(context, bitWidth);
  }

  // Floating point part.
  if (partByteSize > 16)
    TODO(loc, "passing struct as a real > 128 bits in register");
  // When several fp fields are marshalled into one SSE register, clang
  // uses a vector type (like <n x smallest fp field>). Choosing a single
  // fp type of the right size should be equivalent from an ABI point of
  // view, and it keeps the logic here simpler.
  if (partByteSize <= 2)
    return mlir::FloatType::getF16(context);
  if (partByteSize <= 4)
    return mlir::FloatType::getF32(context);
  if (partByteSize <= 8)
    return mlir::FloatType::getF64(context);
  return mlir::FloatType::getF128(context);
}

/// Marshal a derived type passed by value like a C struct.
CodeGenSpecifics::Marshalling
structArgumentType(mlir::Location loc, fir::RecordType recTy,
Expand Down Expand Up @@ -638,9 +668,29 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
marshal.emplace_back(fieldType, AT{});
return marshal;
}
// TODO, marshal the struct with several components, or with a single
// complex, array, or derived type component into registers.
TODO(loc, "passing BIND(C), VALUE derived type in registers on X86-64");
if (Hi == ArgClass::NoClass || Hi == ArgClass::SSEUp) {
// Pass a single integer or floating point argument.
mlir::Type lowType =
pickLLVMArgType(loc, recTy.getContext(), Lo, byteOffset);
CodeGenSpecifics::Marshalling marshal;
marshal.emplace_back(lowType, AT{});
return marshal;
}
// Split into two integer or floating point arguments.
// Note that for the first argument, this will always pick i64 or f64 which
// may be bigger than needed if some struct padding ends the first eight
// byte (e.g. for `{i32, f64}`). It is valid from an X86-64 ABI and
// semantic point of view, but it may not match the LLVM IR interface clang
// would produce for the equivalent C code (the assembly will still be
// compatible). This allows keeping the logic simpler here since it
// avoids computing the "data" size of the Lo part.
mlir::Type lowType = pickLLVMArgType(loc, recTy.getContext(), Lo, 8u);
mlir::Type hiType =
pickLLVMArgType(loc, recTy.getContext(), Hi, byteOffset - 8u);
CodeGenSpecifics::Marshalling marshal;
marshal.emplace_back(lowType, AT{});
marshal.emplace_back(hiType, AT{});
return marshal;
}

/// Marshal an argument that must be passed on the stack.
Expand Down
172 changes: 112 additions & 60 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// We are going to generate an alloca, so save the stack pointer.
if (!savedStackPtr)
savedStackPtr = genStackSave(loc);
auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
auto memTy = fir::ReferenceType::get(ty);
auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem);
return rewriter->create<fir::LoadOp>(loc, cast);
return this->convertValueInMemory(loc, call->getResult(0), ty,
/*inputMayBeBigger=*/true);
};
}

Expand All @@ -195,7 +192,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mlir::Value &savedStackPtr) {
auto resTy = std::get<mlir::Type>(newTypeAndAttr);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
auto oldRefTy = fir::ReferenceType::get(oldType);
// We are going to generate an alloca, so save the stack pointer.
if (!savedStackPtr)
savedStackPtr = genStackSave(loc);
Expand All @@ -206,11 +202,83 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
newOpers.push_back(mem);
} else {
auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
mlir::Value bitcast =
convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
newOpers.push_back(bitcast);
}
}

// Do a bitcast (convert a value via its memory representation).
// The input and output types may have different storage sizes,
// "inputMayBeBigger" should be set to indicate which of the input or
// output type may be bigger in order for the load/store to be safe.
// The mismatch comes from the fact that the LLVM register used for passing
// may be bigger than the value being passed (e.g., passing
// a `!fir.type<t{fir.array<3xi8>}>` into an i32 LLVM register).
//
// `loc` is the location given to the created operations, `value` is the
// value to reinterpret, and `newType` is the type of the returned value.
// The temporary is always allocated with whichever type may be the bigger
// one, and only the reference to it is converted.
mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
                                 mlir::Type newType, bool inputMayBeBigger) {
  if (inputMayBeBigger) {
    // Allocate with the input type, store the input as is, then load it back
    // through a reference converted to the new type.
    auto newRefTy = fir::ReferenceType::get(newType);
    auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
    rewriter->create<fir::StoreOp>(loc, value, mem);
    auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
    return rewriter->create<fir::LoadOp>(loc, cast);
  }
  // Allocate with the new type, store the input through a reference
  // converted to the input type, then load the whole temporary back.
  auto oldRefTy = fir::ReferenceType::get(value.getType());
  auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
  auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
  rewriter->create<fir::StoreOp>(loc, value, cast);
  return rewriter->create<fir::LoadOp>(loc, mem);
}

/// Pass a COMPLEX or derived type operand as several distinct arguments.
///
/// One value per entry of \p splitArgs is extracted from \p oper and
/// appended to \p newOpers. A non COMPLEX operand is first converted, via
/// memory, to a tuple of the new argument types; since that goes through
/// convertValueInMemory, the stack pointer is saved into \p savedStackPtr
/// beforehand if it has not been saved yet.
void passSplitArgument(mlir::Location loc,
                       fir::CodeGenSpecifics::Marshalling splitArgs,
                       mlir::Type oldType, mlir::Value oper,
                       llvm::SmallVectorImpl<mlir::Value> &newOpers,
                       mlir::Value &savedStackPtr) {
  const bool isComplex = fir::isa_complex(oldType);
  if (!isComplex) {
    // Reinterpret the original operand as a tuple holding one element per
    // new argument.
    llvm::SmallVector<mlir::Type> elementTypes;
    for (const auto &part : splitArgs)
      elementTypes.push_back(std::get<mlir::Type>(part));
    mlir::Type tupleTy =
        mlir::TupleType::get(oldType.getContext(), elementTypes);
    if (!savedStackPtr)
      savedStackPtr = genStackSave(loc);
    oper = convertValueInMemory(loc, oper, tupleTy,
                                /*inputMayBeBigger=*/false);
  }

  // Extract each part of the (possibly converted) operand, in order.
  mlir::IntegerType i32Ty = rewriter->getIntegerType(32);
  for (unsigned pos = 0, numParts = splitArgs.size(); pos != numParts; ++pos) {
    mlir::Type partTy = std::get<mlir::Type>(splitArgs[pos]);
    auto posAttr = rewriter->getIntegerAttr(i32Ty, pos);
    auto part = rewriter->create<fir::ExtractValueOp>(
        loc, partTy, oper, rewriter->getArrayAttr(posAttr));
    newOpers.push_back(part);
  }
}

/// Rewrite one COMPLEX or derived type call operand according to
/// \p passArgAs.
///
/// A single-entry marshalling is handled by
/// passArgumentOnStackOrWithNewType; any other marshalling is handled by
/// passSplitArgument. In both cases the new operand(s) are appended to
/// \p newOpers and every entry of \p passArgAs is appended to
/// \p newInTyAndAttrs.
void rewriteCallOperands(
    mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
    mlir::Type originalArgTy, mlir::Value oper,
    llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
    fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
  const bool isSplit = passArgAs.size() != 1;
  if (isSplit)
    // The argument is split into separate arguments.
    passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
                      savedStackPtr);
  else
    // The argument is passed as a single argument.
    passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
                                     newOpers, savedStackPtr);

  // Record the new argument types and attributes in the new signature.
  for (const auto &typeAndAttr : passArgAs)
    newInTyAndAttrs.push_back(typeAndAttr);
}

template <typename CPLX>
Expand All @@ -224,28 +292,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.push_back(oper);
return;
}

auto m = specifics->complexArgumentType(loc, ty.getElementType());
if (m.size() == 1) {
// COMPLEX is a single aggregate
passArgumentOnStackOrWithNewType(loc, m[0], ty, oper, newOpers,
savedStackPtr);
newInTyAndAttrs.push_back(m[0]);
} else {
assert(m.size() == 2);
// COMPLEX is split into 2 separate arguments
auto iTy = rewriter->getIntegerType(32);
for (auto e : llvm::enumerate(m)) {
auto &tup = e.value();
auto ty = std::get<mlir::Type>(tup);
auto index = e.index();
auto idx = rewriter->getIntegerAttr(iTy, index);
auto val = rewriter->create<fir::ExtractValueOp>(
loc, ty, oper, rewriter->getArrayAttr(idx));
newInTyAndAttrs.push_back(tup);
newOpers.push_back(val);
}
}
rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
newInTyAndAttrs);
}

void rewriteCallStructInputType(
Expand All @@ -260,11 +309,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
auto structArgs =
specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
if (structArgs.size() != 1)
TODO(loc, "splitting BIND(C), VALUE derived type into several arguments");
passArgumentOnStackOrWithNewType(loc, structArgs[0], recTy, oper, newOpers,
savedStackPtr);
structArgs.push_back(structArgs[0]);
rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
newInTyAndAttrs);
}

static bool hasByValOrSRetArgs(
Expand Down Expand Up @@ -849,24 +895,21 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
case FixupTy::Codes::ArgumentType: {
// Argument is pass-by-value, but its type has likely been modified to
// suit the target ABI convention.
auto oldArgTy =
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
auto oldArgTy = oldArgTys[fixup.index - offset];
// If type did not change, keep the original argument.
if (fixupType == oldArgTy)
break;

auto newArg =
func.front().insertArgument(fixup.index, fixupType, loc);
rewriter->setInsertionPointToStart(&func.front());
auto mem = rewriter->create<fir::AllocaOp>(loc, fixupType);
rewriter->create<fir::StoreOp>(loc, newArg, mem);
auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
/*inputMayBeBigger=*/true);
func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
func.front().eraseArgument(fixup.index + 1);
LLVM_DEBUG(llvm::dbgs()
<< "old argument: " << oldArgTy.getEleTy()
<< ", repl: " << load << ", new argument: "
<< "old argument: " << oldArgTy << ", repl: " << bitcast
<< ", new argument: "
<< func.getArgument(fixup.index).getType() << '\n');
} break;
case FixupTy::Codes::CharPair: {
Expand Down Expand Up @@ -907,34 +950,43 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
func.walk([&](mlir::func::ReturnOp ret) {
rewriter->setInsertionPoint(ret);
auto oldOper = ret.getOperand(0);
auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
auto mem =
rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
rewriter->create<fir::StoreOp>(loc, oldOper, cast);
mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
rewriter->create<mlir::func::ReturnOp>(loc, load);
mlir::Value bitcast =
convertValueInMemory(loc, oldOper, newResTys[fixup.index],
/*inputMayBeBigger=*/false);
rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
ret.erase();
});
} break;
case FixupTy::Codes::Split: {
// The FIR argument has been split into a pair of distinct arguments
// that are in juxtaposition to each other. (For COMPLEX value.)
// that are in juxtaposition to each other. (For COMPLEX value or
// derived type passed with VALUE in BIND(C) context).
auto newArg =
func.front().insertArgument(fixup.index, fixupType, loc);
if (fixup.second == 1) {
rewriter->setInsertionPointToStart(&func.front());
auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
mlir::Type originalTy =
oldArgTys[fixup.index - offset - fixup.second];
mlir::Type pairTy = originalTy;
if (!fir::isa_complex(originalTy)) {
pairTy = mlir::TupleType::get(
originalTy.getContext(),
mlir::TypeRange{firstArg.getType(), newArg.getType()});
}
auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
auto iTy = rewriter->getIntegerType(32);
auto zero = rewriter->getIntegerAttr(iTy, 0);
auto one = rewriter->getIntegerAttr(iTy, 1);
auto cplx1 = rewriter->create<fir::InsertValueOp>(
loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
rewriter->getArrayAttr(zero));
auto cplx = rewriter->create<fir::InsertValueOp>(
loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
mlir::Value pair = rewriter->create<fir::InsertValueOp>(
loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
// Cast local argument tuple to original type via memory if needed.
if (pairTy != originalTy)
pair = convertValueInMemory(loc, pair, originalTy,
/*inputMayBeBigger=*/true);
func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
func.front().eraseArgument(fixup.index + 1);
offset++;
}
Expand Down
Loading

0 comments on commit 011ba72

Please sign in to comment.