diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index ff3f9c4d6e7786..a4df0b09177ab7 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -604,6 +604,36 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
     return {};
   }
 
+  mlir::Type pickLLVMArgType(mlir::Location loc, mlir::MLIRContext *context,
+                             ArgClass argClass,
+                             std::uint64_t partByteSize) const {
+    if (argClass == ArgClass::SSE) {
+      if (partByteSize > 16)
+        TODO(loc, "passing struct as a real > 128 bits in register");
+      // Clang uses a vector type when several fp fields are marshalled
+      // into a single SSE register (like <2 x float>).
+      // It should make no difference from an ABI point of view to just
+      // select an fp type of the right size, and it makes things simpler
+      // here.
+      if (partByteSize > 8)
+        return mlir::FloatType::getF128(context);
+      if (partByteSize > 4)
+        return mlir::FloatType::getF64(context);
+      if (partByteSize > 2)
+        return mlir::FloatType::getF32(context);
+      return mlir::FloatType::getF16(context);
+    }
+    assert(partByteSize <= 8 &&
+           "expect integer part of aggregate argument to fit into eight bytes");
+    if (partByteSize > 4)
+      return mlir::IntegerType::get(context, 64);
+    if (partByteSize > 2)
+      return mlir::IntegerType::get(context, 32);
+    if (partByteSize > 1)
+      return mlir::IntegerType::get(context, 16);
+    return mlir::IntegerType::get(context, 8);
+  }
+
   /// Marshal a derived type passed by value like a C struct.
   CodeGenSpecifics::Marshalling
   structArgumentType(mlir::Location loc, fir::RecordType recTy,
@@ -638,9 +668,29 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
       marshal.emplace_back(fieldType, AT{});
       return marshal;
     }
-    // TODO, marshal the struct with several components, or with a single
-    // complex, array, or derived type component into registers.
-    TODO(loc, "passing BIND(C), VALUE derived type in registers on X86-64");
+    if (Hi == ArgClass::NoClass || Hi == ArgClass::SSEUp) {
+      // Pass a single integer or floating point argument.
+      mlir::Type lowType =
+          pickLLVMArgType(loc, recTy.getContext(), Lo, byteOffset);
+      CodeGenSpecifics::Marshalling marshal;
+      marshal.emplace_back(lowType, AT{});
+      return marshal;
+    }
+    // Split into two integer or floating point arguments.
+    // Note that for the first argument, this will always pick i64 or f64,
+    // which may be bigger than needed if some struct padding ends the first
+    // eight bytes (e.g. for `{i32, f64}`). It is valid from an X86-64 ABI and
+    // semantic point of view, but it may not match the LLVM IR interface clang
+    // would produce for the equivalent C code (the assembly will still be
+    // compatible). This allows keeping the logic simpler here since it
+    // avoids computing the "data" size of the Lo part.
+    mlir::Type lowType = pickLLVMArgType(loc, recTy.getContext(), Lo, 8u);
+    mlir::Type hiType =
+        pickLLVMArgType(loc, recTy.getContext(), Hi, byteOffset - 8u);
+    CodeGenSpecifics::Marshalling marshal;
+    marshal.emplace_back(lowType, AT{});
+    marshal.emplace_back(hiType, AT{});
+    return marshal;
   }
 
   /// Marshal an argument that must be passed on the stack.
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 2f5c8cc0071ae1..f324e18c65465f 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -180,11 +180,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       // We are going to generate an alloca, so save the stack pointer.
       if (!savedStackPtr)
         savedStackPtr = genStackSave(loc);
-      auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
-      rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
-      auto memTy = fir::ReferenceType::get(ty);
-      auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem);
-      return rewriter->create<fir::LoadOp>(loc, cast);
+      return this->convertValueInMemory(loc, call->getResult(0), ty,
+                                        /*inputMayBeBigger=*/true);
     };
   }
 
@@ -195,7 +192,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                                         mlir::Value &savedStackPtr) {
     auto resTy = std::get<mlir::Type>(newTypeAndAttr);
     auto attr = std::get<fir::CodeGenSpecifics::Attributes>(newTypeAndAttr);
-    auto oldRefTy = fir::ReferenceType::get(oldType);
     // We are going to generate an alloca, so save the stack pointer.
     if (!savedStackPtr)
       savedStackPtr = genStackSave(loc);
@@ -206,11 +202,83 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       mem = rewriter->create<fir::ConvertOp>(loc, resTy, mem);
       newOpers.push_back(mem);
     } else {
-      auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
+      mlir::Value bitcast =
+          convertValueInMemory(loc, oper, resTy, /*inputMayBeBigger=*/false);
+      newOpers.push_back(bitcast);
+    }
+  }
+
+  // Do a bitcast (convert a value via its memory representation).
+  // The input and output types may have different storage sizes;
+  // "inputMayBeBigger" should be set to indicate which of the input or
+  // output type may be bigger in order for the load/store to be safe.
+  // The mismatch comes from the fact that the LLVM register used for passing
+  // may be bigger than the value being passed (e.g., passing
+  // a small `!fir.type<...>` value into an i32 LLVM register).
+  mlir::Value convertValueInMemory(mlir::Location loc, mlir::Value value,
+                                   mlir::Type newType, bool inputMayBeBigger) {
+    if (inputMayBeBigger) {
+      auto newRefTy = fir::ReferenceType::get(newType);
+      auto mem = rewriter->create<fir::AllocaOp>(loc, value.getType());
+      rewriter->create<fir::StoreOp>(loc, value, mem);
+      auto cast = rewriter->create<fir::ConvertOp>(loc, newRefTy, mem);
+      return rewriter->create<fir::LoadOp>(loc, cast);
+    } else {
+      auto oldRefTy = fir::ReferenceType::get(value.getType());
+      auto mem = rewriter->create<fir::AllocaOp>(loc, newType);
       auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
-      rewriter->create<fir::StoreOp>(loc, oper, cast);
-      newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
+      rewriter->create<fir::StoreOp>(loc, value, cast);
+      return rewriter->create<fir::LoadOp>(loc, mem);
+    }
+  }
+
+  void passSplitArgument(mlir::Location loc,
+                         fir::CodeGenSpecifics::Marshalling splitArgs,
+                         mlir::Type oldType, mlir::Value oper,
+                         llvm::SmallVectorImpl<mlir::Value> &newOpers,
+                         mlir::Value &savedStackPtr) {
+    // COMPLEX or struct argument split into separate arguments.
+    if (!fir::isa_complex(oldType)) {
+      // Cast original operand to a tuple of the new arguments
+      // via memory.
+      llvm::SmallVector<mlir::Type> partTypes;
+      for (auto argPart : splitArgs)
+        partTypes.push_back(std::get<mlir::Type>(argPart));
+      mlir::Type tupleType =
+          mlir::TupleType::get(oldType.getContext(), partTypes);
+      if (!savedStackPtr)
+        savedStackPtr = genStackSave(loc);
+      oper = convertValueInMemory(loc, oper, tupleType,
+                                  /*inputMayBeBigger=*/false);
+    }
+    auto iTy = rewriter->getIntegerType(32);
+    for (auto e : llvm::enumerate(splitArgs)) {
+      auto &tup = e.value();
+      auto ty = std::get<mlir::Type>(tup);
+      auto index = e.index();
+      auto idx = rewriter->getIntegerAttr(iTy, index);
+      auto val = rewriter->create<fir::ExtractValueOp>(
+          loc, ty, oper, rewriter->getArrayAttr(idx));
+      newOpers.push_back(val);
+    }
+  }
+
+  void rewriteCallOperands(
+      mlir::Location loc, fir::CodeGenSpecifics::Marshalling passArgAs,
+      mlir::Type originalArgTy, mlir::Value oper,
+      llvm::SmallVectorImpl<mlir::Value> &newOpers, mlir::Value &savedStackPtr,
+      fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
+    if (passArgAs.size() == 1) {
+      // COMPLEX or derived type is passed as a single argument.
+      passArgumentOnStackOrWithNewType(loc, passArgAs[0], originalArgTy, oper,
+                                       newOpers, savedStackPtr);
+    } else {
+      // COMPLEX or derived type is split into separate arguments.
+      passSplitArgument(loc, passArgAs, originalArgTy, oper, newOpers,
+                        savedStackPtr);
     }
+    newInTyAndAttrs.insert(newInTyAndAttrs.end(), passArgAs.begin(),
+                           passArgAs.end());
   }
 
   template
@@ -224,28 +292,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       newOpers.push_back(oper);
       return;
     }
     auto m = specifics->complexArgumentType(loc, ty.getElementType());
-    if (m.size() == 1) {
-      // COMPLEX is a single aggregate
-      passArgumentOnStackOrWithNewType(loc, m[0], ty, oper, newOpers,
-                                       savedStackPtr);
-      newInTyAndAttrs.push_back(m[0]);
-    } else {
-      assert(m.size() == 2);
-      // COMPLEX is split into 2 separate arguments
-      auto iTy = rewriter->getIntegerType(32);
-      for (auto e : llvm::enumerate(m)) {
-        auto &tup = e.value();
-        auto ty = std::get<mlir::Type>(tup);
-        auto index = e.index();
-        auto idx = rewriter->getIntegerAttr(iTy, index);
-        auto val = rewriter->create<fir::ExtractValueOp>(
-            loc, ty, oper, rewriter->getArrayAttr(idx));
-        newInTyAndAttrs.push_back(tup);
-        newOpers.push_back(val);
-      }
-    }
+    rewriteCallOperands(loc, m, ty, oper, newOpers, savedStackPtr,
+                        newInTyAndAttrs);
   }
 
   void rewriteCallStructInputType(
@@ -260,11 +309,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     }
     auto structArgs =
        specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
-    if (structArgs.size() != 1)
-      TODO(loc, "splitting BIND(C), VALUE derived type into several arguments");
-    passArgumentOnStackOrWithNewType(loc, structArgs[0], recTy, oper, newOpers,
-                                     savedStackPtr);
-    structArgs.push_back(structArgs[0]);
+    rewriteCallOperands(loc, structArgs, recTy, oper, newOpers, savedStackPtr,
+                        newInTyAndAttrs);
   }
 
   static bool hasByValOrSRetArgs(
@@ -849,8 +895,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       case FixupTy::Codes::ArgumentType: {
         // Argument is pass-by-value, but its type has likely been modified to
         // suit the target ABI convention.
-        auto oldArgTy =
-            fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
+        auto oldArgTy = oldArgTys[fixup.index - offset];
+        // If type did not change, keep the original argument.
+        if (fixupType == oldArgTy)
+          break;
@@ -858,15 +903,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
         auto newArg = func.front().insertArgument(fixup.index, fixupType, loc);
         rewriter->setInsertionPointToStart(&func.front());
-        auto mem = rewriter->create<fir::AllocaOp>(loc, fixupType);
-        rewriter->create<fir::StoreOp>(loc, newArg, mem);
-        auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
-        mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
-        func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
+        mlir::Value bitcast = convertValueInMemory(loc, newArg, oldArgTy,
+                                                   /*inputMayBeBigger=*/true);
+        func.getArgument(fixup.index + 1).replaceAllUsesWith(bitcast);
         func.front().eraseArgument(fixup.index + 1);
         LLVM_DEBUG(llvm::dbgs()
-                   << "old argument: " << oldArgTy.getEleTy()
-                   << ", repl: " << load << ", new argument: "
+                   << "old argument: " << oldArgTy << ", repl: " << bitcast
+                   << ", new argument: "
                    << func.getArgument(fixup.index).getType() << '\n');
       } break;
       case FixupTy::Codes::CharPair: {
@@ -907,34 +950,43 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
        func.walk([&](mlir::func::ReturnOp ret) {
          rewriter->setInsertionPoint(ret);
          auto oldOper = ret.getOperand(0);
-         auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-         auto mem =
-             rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
-         auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
-         rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-         mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
-         rewriter->create<mlir::func::ReturnOp>(loc, load);
+         mlir::Value bitcast =
+             convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                  /*inputMayBeBigger=*/false);
+         rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
          ret.erase();
        });
      } break;
      case FixupTy::Codes::Split: {
        // The FIR argument has been split into a pair of distinct arguments
-       // that are in juxtaposition to each other. (For COMPLEX value.)
+       // that are in juxtaposition to each other. (For COMPLEX value or
+       // derived type passed with VALUE in BIND(C) context).
        auto newArg = func.front().insertArgument(fixup.index, fixupType, loc);
        if (fixup.second == 1) {
          rewriter->setInsertionPointToStart(&func.front());
-         auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
-         auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
+         mlir::Value firstArg = func.front().getArgument(fixup.index - 1);
+         mlir::Type originalTy =
+             oldArgTys[fixup.index - offset - fixup.second];
+         mlir::Type pairTy = originalTy;
+         if (!fir::isa_complex(originalTy)) {
+           pairTy = mlir::TupleType::get(
+               originalTy.getContext(),
+               mlir::TypeRange{firstArg.getType(), newArg.getType()});
+         }
+         auto undef = rewriter->create<fir::UndefOp>(loc, pairTy);
         auto iTy = rewriter->getIntegerType(32);
         auto zero = rewriter->getIntegerAttr(iTy, 0);
         auto one = rewriter->getIntegerAttr(iTy, 1);
-        auto cplx1 = rewriter->create<fir::InsertValueOp>(
-            loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
-            rewriter->getArrayAttr(zero));
-        auto cplx = rewriter->create<fir::InsertValueOp>(
-            loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
-        func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
+        mlir::Value pair1 = rewriter->create<fir::InsertValueOp>(
+            loc, pairTy, undef, firstArg, rewriter->getArrayAttr(zero));
+        mlir::Value pair = rewriter->create<fir::InsertValueOp>(
+            loc, pairTy, pair1, newArg, rewriter->getArrayAttr(one));
+        // Cast local argument tuple to original type via memory if needed.
+        if (pairTy != originalTy)
+          pair = convertValueInMemory(loc, pair, originalTy,
+                                      /*inputMayBeBigger=*/true);
+        func.getArgument(fixup.index + 1).replaceAllUsesWith(pair);
         func.front().eraseArgument(fixup.index + 1);
         offset++;
       }
diff --git a/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
new file mode 100644
index 00000000000000..82139492cea700
--- /dev/null
+++ b/flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir
@@ -0,0 +1,159 @@
+// Test X86-64 passing ABI of struct in registers for the cases where the
+// struct has more than one field.
+// REQUIRES: x86-registered-target
+// RUN: fir-opt -target-rewrite="target=x86_64-unknown-linux-gnu" %s -o - | FileCheck %s
+
+
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+func.func @test_call_i8_a16(%0 : !fir.ref}>>) {
+  %1 = fir.load %0 : !fir.ref}>>
+  fir.call @test_func_i8_a16(%1) : (!fir.type}>) -> ()
+  return
+}
+// CHECK-LABEL: func.func @test_call_i8_a16(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref}>>) {
+// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref}>>
+// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref
+// CHECK: %[[VAL_3:.*]] = fir.alloca tuple
+// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref>) -> !fir.ref}>>
+// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref}>>
+// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref>
+// CHECK: %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple) -> i64
+// CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple) -> i64
+// CHECK: fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
+// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref) -> ()
+// CHECK: return
+
+func.func private @test_func_i8_a16(%0 : !fir.type}>) -> () {
+  return
+}
+// CHECK-LABEL: func.func private @test_func_i8_a16(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) {
+// CHECK: %[[VAL_2:.*]] = fir.undefined tuple
+// CHECK: %[[VAL_3:.*]] = fir.insert_value %[[VAL_2]], %[[VAL_0]], [0 : i32] : (tuple, i64) -> tuple
+// CHECK: %[[VAL_4:.*]] = fir.insert_value %[[VAL_3]], %[[VAL_1]], [1 : i32] : (tuple, i64) -> tuple
+// CHECK: %[[VAL_5:.*]] = fir.alloca tuple
+// CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref>
+// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref>) -> !fir.ref}>>
+// CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_6]] : !fir.ref}>>
+// CHECK: return
+
+
+// For the cases below, the argument marshalling logic is the same as above,
+// so only the chosen signature is tested at the end.
+
+func.func @test_call_i32_f32(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_i32_f32(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_i32_f32(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_i32_i16(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_i32_i16(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_i32_i16(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_f16_i16(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_f16_i16(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_f16_i16(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_f16_f16(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_f16_f16(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_f16_f16(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_i32_f64(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_i32_f64(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_i32_f64(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_f64_f32(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_f64_f32(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_f64_f32(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_f32_i32_f32_f32(%0 : !fir.ref>) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_f32_i32_f32_f32(%1) : (!fir.type) -> ()
+  return
+}
+func.func private @test_func_f32_i32_f32_f32(%0 : !fir.type) -> () {
+  return
+}
+
+func.func @test_call_f64_i32(%before : i16, %0 : !fir.ref>, %after : f128) {
+  %1 = fir.load %0 : !fir.ref>
+  fir.call @test_func_f64_i32(%before, %1, %after) : (i16, !fir.type, f128) -> ()
+  return
+}
+func.func private @test_func_f64_i32(%before : i16, %0 : !fir.type, %after : f128) -> () {
+  return
+}
+}
+
+// CHECK-LABEL: func.func @test_call_i32_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_i32_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: i64) {
+// CHECK-LABEL: func.func @test_call_i32_i16(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_i32_i16(
+// CHECK-SAME: %[[VAL_0:.*]]: i64) {
+// CHECK-LABEL: func.func @test_call_f16_i16(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_f16_i16(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) {
+// CHECK-LABEL: func.func @test_call_f16_f16(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_f16_f16(
+// CHECK-SAME: %[[VAL_0:.*]]: f32) {
+// CHECK-LABEL: func.func @test_call_i32_f64(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_i32_f64(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: f64) {
+// CHECK-LABEL: func.func @test_call_f64_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_f64_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: f64,
+// CHECK-SAME: %[[VAL_1:.*]]: f32) {
+// CHECK-LABEL: func.func @test_call_f32_i32_f32_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) {
+// CHECK-LABEL: func.func private @test_func_f32_i32_f32_f32(
+// CHECK-SAME: %[[VAL_0:.*]]: i64,
+// CHECK-SAME: %[[VAL_1:.*]]: f64) {
+// CHECK-LABEL: func.func @test_call_f64_i32(
+// CHECK-SAME: %[[VAL_0:.*]]: i16,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref>,
+// CHECK-SAME: %[[VAL_2:.*]]: f128) {
+// CHECK-LABEL: func.func private @test_func_f64_i32(
+// CHECK-SAME: %[[VAL_0:.*]]: i16,
+// CHECK-SAME: %[[VAL_1:.*]]: f64,
+// CHECK-SAME: %[[VAL_2:.*]]: i32,
+// CHECK-SAME: %[[VAL_3:.*]]: f128) {
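
For context, here is a minimal C sketch of the kind of interface this lowering targets: a BIND(C), VALUE derived type whose two eightbytes classify as INTEGER and SSE, i.e. the `{i32, f64}` case discussed in the `structArgumentType` comment above. The struct and function names below are hypothetical and do not appear in the patch.

```c
/* Hypothetical C counterpart of a BIND(C), VALUE derived type argument. */
#include <stdint.h>

typedef struct {
  int32_t i; /* first eightbyte: INTEGER class (followed by 4 bytes of padding) */
  double x;  /* second eightbyte: SSE class */
} t_i32_f64;

/* Under the SysV x86-64 ABI this struct is passed in two registers.
   With this patch, flang marshals the FIR value as (i64, f64) because the
   first eightbyte is always widened to 8 bytes; clang would emit
   (i32, double) for the same C signature. Both end up in %rdi/%edi and
   %xmm0, so the generated assembly stays call-compatible. */
void takes_t_by_value(t_i32_f64 v);
```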