Skip to content

Commit

Permalink
[NVPTX] Support inline asm with 128-bit operand in NVPTX backend (llv…
Browse files Browse the repository at this point in the history
…m#97113)

This change supports the 128-bit operands for inline ptx asm, both input
and output.\
\
The major changes are:

- Tablegen:\
    Define Int128Regs in NVPTXRegisterInfo.td. This register class is not
set as a general register type in the NVPTX backend, so this change
does not influence codegen without inline asm.\
    Define three NVPTX instructions: IMOV128rr, V2I64toI128 and
I128toV2I64. The first one moves a 128-bit register, the second one moves
two 64-bit registers into one 128-bit register, and the third one does
the opposite.
- NVPTXISelLowering & NVPTXISelDAGToDAG:\
    Custom-lower CopyToReg and CopyFromReg with 128-bit operands.
CopyToReg deals with the inputs of the inline asm, and CopyFromReg
deals with the outputs.\
    CopyToReg is custom-lowered into a V2I64toI128, which takes the
expanded values (Lo and Hi) of the input and moves them into a 128-bit
register.\
    CopyFromReg is custom-lowered by adding an I128toV2I64, which breaks
down the 128-bit outputs of the inline asm into the expanded values.
  • Loading branch information
Chengjunp authored and kbluck committed Jul 6, 2024
1 parent 80e4ab1 commit a919fb1
Show file tree
Hide file tree
Showing 15 changed files with 529 additions and 3 deletions.
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/NVPTX.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
case 'l':
case 'f':
case 'd':
case 'q':
Info.setAllowsRegister();
return true;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5381,6 +5381,7 @@ NVPTX:
- ``c`` or ``h``: A 16-bit integer register.
- ``r``: A 32-bit integer register.
- ``l`` or ``N``: A 64-bit integer register.
- ``q``: A 128-bit integer register.
- ``f``: A 32-bit float register.
- ``d``: A 64-bit float register.

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
case 6:
OS << "%fd";
break;
case 7:
OS << "%rq";
break;
}

unsigned VReg = Reg.id() & 0x0FFFFFFF;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Ret = (5 << 28);
} else if (RC == &NVPTX::Float64RegsRegClass) {
Ret = (6 << 28);
} else if (RC == &NVPTX::Int128RegsRegClass) {
Ret = (7 << 28);
} else {
report_fatal_error("Bad register class");
}
Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,20 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
if (tryConstantFP(N))
return;
break;
case ISD::CopyToReg: {
if (N->getOperand(1).getValueType() == MVT::i128) {
SelectV2I64toI128(N);
return;
}
break;
}
case ISD::CopyFromReg: {
if (N->getOperand(1).getValueType() == MVT::i128) {
SelectI128toV2I64(N);
return;
}
break;
}
default:
break;
}
Expand Down Expand Up @@ -3798,6 +3812,60 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
return true;
}

void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
  // Select a CopyToReg whose 128-bit value arrives as two 64-bit halves by
  // first packing the halves into a single 128-bit register:
  //
  //   CopyToReg Dst:i128, lo:i64, hi:i64
  //
  // becomes
  //
  //   tmp:i128 = V2I64toI128 lo, hi
  //   CopyToReg Dst, tmp
  const SDLoc Loc(N);
  SDValue Chain = N->getOperand(0);
  SDValue DstReg = N->getOperand(1);
  SDValue LoHalf = N->getOperand(2);
  SDValue HiHalf = N->getOperand(3);

  SDNode *Packed = CurDAG->getMachineNode(NVPTX::V2I64toI128, Loc, MVT::i128,
                                          {LoHalf, HiHalf});

  // Rebuild the copy with the packed value standing in for the two halves.
  SmallVector<SDValue, 4> CopyOps = {Chain, DstReg, SDValue(Packed, 0)};
  // A fifth operand on the incoming node is forwarded unchanged (this is the
  // optional glue that LowerCopyToReg_128 carries over).
  if (N->getNumOperands() == 5)
    CopyOps.push_back(N->getOperand(4));

  // The replacement produces exactly the same result types as N.
  SmallVector<EVT> ResultTys(N->values());
  SDValue NewCopy = CurDAG->getNode(ISD::CopyToReg, Loc, ResultTys, CopyOps);

  ReplaceNode(N, NewCopy.getNode());
}

void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
  // Select a CopyFromReg that reads a 128-bit register and yields two 64-bit
  // values by replacing it with the unpacking move:
  //
  //   {lo, hi} = CopyFromReg Src:i128
  //
  // becomes
  //
  //   {lo, hi} = I128toV2I64 Src
  const SDLoc Loc(N);
  SDValue Chain = N->getOperand(0);
  SDValue SrcReg = N->getOperand(1);
  // NOTE(review): operand 2 is read unconditionally, so N is assumed to carry
  // a glue operand -- confirm every 128-bit CopyFromReg reaching here has one.
  SDValue InGlue = N->getOperand(2);

  // Thread the chain and glue through the machine node, both as operands and
  // as results, so that the execution order is not broken.
  SmallVector<EVT, 4> ResultTys = {MVT::i64, MVT::i64, Chain.getValueType(),
                                   InGlue.getValueType()};
  SmallVector<SDValue, 3> MovOps = {SrcReg, Chain, InGlue};
  SDNode *Unpacked =
      CurDAG->getMachineNode(NVPTX::I128toV2I64, Loc, ResultTys, MovOps);

  ReplaceNode(N, Unpacked);
}

/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
/// conversion from \p SrcTy to \p DestTy.
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool SelectSETP_F16X2(SDNode *N);
bool SelectSETP_BF16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);

void SelectV2I64toI128(SDNode *N);
void SelectI128toV2I64(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
}
Expand Down
88 changes: 88 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
}

// Custom lowering for inline asm with 128-bit operands
setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);

// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.

Expand Down Expand Up @@ -2804,6 +2808,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVectorArith(Op, DAG);
case ISD::DYNAMIC_STACKALLOC:
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
default:
llvm_unreachable("Custom lowering not defined for operation");
}
Expand Down Expand Up @@ -3094,6 +3100,54 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
return Result;
}

/// Rewrite a CopyToReg of a 128-bit value so that it carries two 64-bit
/// operands (low half, then high half) instead of one 128-bit operand, which
/// lets the node pass legalization. Chain, destination register and the
/// optional glue are forwarded unchanged.
SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
                                                SelectionDAG &DAG) const {
  assert(Op.getOperand(1).getValueType() == MVT::i128 &&
         "Custom lowering for 128-bit CopyToReg only");

  SDLoc DL(Op);

  // View the 128-bit value as <2 x i64> and pull out each element.
  SDValue AsV2I64 = DAG.getBitcast(MVT::v2i64, Op.getOperand(2));
  auto ExtractHalf = [&](unsigned Idx) {
    return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, AsV2I64,
                       DAG.getIntPtrConstant(Idx, DL));
  };
  SDValue Lo = ExtractHalf(0);
  SDValue Hi = ExtractHalf(1);

  // Chain, Dst Reg, lower 64 bits, higher 64 bits.
  SmallVector<SDValue, 5> SplitOps = {Op.getOperand(0), Op.getOperand(1), Lo,
                                      Hi};
  // Glue, if the original node had one.
  if (Op.getNumOperands() == 4)
    SplitOps.push_back(Op.getOperand(3));

  // The result types are those of the original node.
  SmallVector<EVT, 3> ResultTys(Op->values());
  return DAG.getNode(ISD::CopyToReg, DL, ResultTys, SplitOps);
}

/// Report how many registers are needed to hold a value of type \p VT.
/// An i128 value placed in an i128 register needs exactly one; every other
/// combination is answered by TargetLoweringBase.
unsigned NVPTXTargetLowering::getNumRegisters(
    LLVMContext &Context, EVT VT,
    std::optional<MVT> RegisterVT = std::nullopt) const {
  const bool FitsOne128BitReg = VT == MVT::i128 && RegisterVT == MVT::i128;
  return FitsOne128BitReg
             ? 1
             : TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
}

/// Target hook for splitting \p Val into \p NumParts register parts.
/// Only the case of an i128 value going into a single part is handled here:
/// the value is stored unsplit in Parts[0] and true is returned. For anything
/// else the function returns false and writes nothing.
bool NVPTXTargetLowering::splitValueIntoRegisterParts(
    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
  const bool IsUnsplitI128 =
      Val.getValueType() == MVT::i128 && NumParts == 1;
  if (!IsUnsplitI128)
    return false;

  Parts[0] = Val;
  return true;
}

// This creates target external symbol for a function parameter.
// Name of the symbol is composed from its index and the function name.
// Negative index corresponds to special parameter (unsized array) used for
Expand Down Expand Up @@ -5150,6 +5204,7 @@ NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
case 'l':
case 'f':
case 'd':
case 'q':
case '0':
case 'N':
return C_RegisterClass;
Expand All @@ -5175,6 +5230,12 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
case 'l':
case 'N':
return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
case 'q': {
if (STI.getSmVersion() < 70)
report_fatal_error("Inline asm with 128 bit operands is only "
"supported for sm_70 and higher!");
return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
}
case 'f':
return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
case 'd':
Expand Down Expand Up @@ -6261,6 +6322,30 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
}
}

/// Replace the results of a CopyFromReg that reads a 128-bit register.
/// A new CopyFromReg producing two i64 values is built so that it can pass
/// legalization, and the two halves are rejoined with BUILD_PAIR. \p Results
/// receives the rebuilt i128 value followed by the new node's results 2 and 3,
/// which take their types from N's results 1 and 2.
static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
                                   SmallVectorImpl<SDValue> &Results) {
  SDLoc DL(N);
  SDValue Chain = N->getOperand(0);
  SDValue SrcReg = N->getOperand(1);
  // NOTE(review): operand 2 and result 2 are accessed unconditionally, so N is
  // assumed to be a glued CopyFromReg -- confirm against the callers.
  SDValue InGlue = N->getOperand(2);

  assert(SrcReg.getValueType() == MVT::i128 &&
         "Custom lowering for CopyFromReg with 128-bit reg only");

  // Same operands as N, but the single i128 result becomes two i64 results.
  SmallVector<EVT, 4> SplitTys = {MVT::i64, MVT::i64, N->getValueType(1),
                                  N->getValueType(2)};
  SmallVector<SDValue, 3> SplitOps = {Chain, SrcReg, InGlue};
  SDValue Split = DAG.getNode(ISD::CopyFromReg, DL, SplitTys, SplitOps);

  // Rejoin the halves so that users of the original i128 result keep working.
  SDValue Lo = Split.getValue(0);
  SDValue Hi = Split.getValue(1);
  SDValue Whole = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128, {Lo, Hi});

  Results.append({Whole, Split.getValue(2), Split.getValue(3)});
}

void NVPTXTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
switch (N->getOpcode()) {
Expand All @@ -6272,6 +6357,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
case ISD::INTRINSIC_W_CHAIN:
ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
return;
case ISD::CopyFromReg:
ReplaceCopyFromReg_128(N, DAG, Results);
return;
}
}

Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
unsigned getNumRegisters(LLVMContext &Context, EVT VT,
std::optional<MVT> RegisterVT) const override;
bool
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
SDValue *Parts, unsigned NumParts, MVT PartVT,
std::optional<CallingConv::ID> CC) const override;

void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const override;
SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
} else if (DestRC == &NVPTX::Int64RegsRegClass) {
Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
: NVPTX::BITCONVERT_64_F2I);
} else if (DestRC == &NVPTX::Int128RegsRegClass) {
Op = NVPTX::IMOV128rr;
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
: NVPTX::BITCONVERT_32_I2F);
Expand Down
12 changes: 10 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,8 @@ let IsSimpleMove=1, hasSideEffects=0 in {
"mov.u32 \t$dst, $sss;", []>;
def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss),
"mov.u64 \t$dst, $sss;", []>;
// Register-to-register move between two Int128Regs, emitted as mov.b128.
def IMOV128rr : NVPTXInst<(outs Int128Regs:$dst), (ins Int128Regs:$sss),
"mov.b128 \t$dst, $sss;", []>;

def IMOVB16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$sss),
"mov.b16 \t$dst, $sss;", []>;
Expand Down Expand Up @@ -3545,6 +3547,9 @@ let hasSideEffects = false in {
def V2I32toI64 : NVPTXInst<(outs Int64Regs:$d),
(ins Int32Regs:$s1, Int32Regs:$s2),
"mov.b64 \t$d, {{$s1, $s2}};", []>;
// Pack two 64-bit registers into one 128-bit register with a single
// mov.b128. No selection pattern is attached; the instruction is created
// explicitly when a 128-bit CopyToReg is selected, with $s1 = low half and
// $s2 = high half.
def V2I64toI128 : NVPTXInst<(outs Int128Regs:$d),
(ins Int64Regs:$s1, Int64Regs:$s2),
"mov.b128 \t$d, {{$s1, $s2}};", []>;
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
(ins Float32Regs:$s1, Float32Regs:$s2),
"mov.b64 \t$d, {{$s1, $s2}};", []>;
Expand All @@ -3560,6 +3565,9 @@ let hasSideEffects = false in {
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
(ins Int64Regs:$s),
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
// Unpack one 128-bit register into two 64-bit registers with a single
// mov.b128 (the inverse of V2I64toI128). No selection pattern is attached;
// the instruction is created explicitly when a 128-bit CopyFromReg is
// selected.
def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
(ins Int128Regs:$s),
"mov.b128 \t{{$d1, $d2}}, $s;", []>;
def F64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
(ins Float64Regs:$s),
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
Expand Down Expand Up @@ -3629,7 +3637,7 @@ def : Pat<(i32 (ctlz (i32 Int32Regs:$a))), (CLZr32 Int32Regs:$a)>;
// ptx value to 64 bits to match the ISD node's semantics, unless we know we're
// truncating back down to 32 bits.
def : Pat<(i64 (ctlz Int64Regs:$a)), (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>;
def : Pat<(i32 (trunc (ctlz Int64Regs:$a))), (CLZr64 Int64Regs:$a)>;
def : Pat<(i32 (trunc (i64 (ctlz Int64Regs:$a)))), (CLZr64 Int64Regs:$a)>;

// For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
// result back to 16-bits if necessary. We also need to subtract 16 because
Expand Down Expand Up @@ -3667,7 +3675,7 @@ def : Pat<(i32 (ctpop (i32 Int32Regs:$a))), (POPCr32 Int32Regs:$a)>;
// pattern that avoids the type conversion if we're truncating the result to
// i32 anyway.
def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>;
def : Pat<(i32 (trunc (ctpop Int64Regs:$a))), (POPCr64 Int64Regs:$a)>;
def : Pat<(i32 (trunc (i64 (ctpop Int64Regs:$a)))), (POPCr64 Int64Regs:$a)>;

// For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
// If we know that we're storing into an i32, we can avoid the final trunc.
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
return ".f32";
if (RC == &NVPTX::Float64RegsRegClass)
return ".f64";
if (RC == &NVPTX::Int128RegsRegClass)
return ".b128";
if (RC == &NVPTX::Int64RegsRegClass)
// We use untyped (.b) integer registers here as NVCC does.
// Correctness of generated code does not depend on register type,
Expand Down Expand Up @@ -67,6 +69,8 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
return "%f";
if (RC == &NVPTX::Float64RegsRegClass)
return "%fd";
if (RC == &NVPTX::Int128RegsRegClass)
return "%rq";
if (RC == &NVPTX::Int64RegsRegClass)
return "%rd";
if (RC == &NVPTX::Int32RegsRegClass)
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ foreach i = 0...4 in {
def RS#i : NVPTXReg<"%rs"#i>; // 16-bit
def R#i : NVPTXReg<"%r"#i>; // 32-bit
def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
def H#i : NVPTXReg<"%h"#i>; // 16-bit float
def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
def F#i : NVPTXReg<"%f"#i>; // 32-bit float
Expand All @@ -62,6 +63,8 @@ def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
(add (sequence "R%u", 0, 4),
VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
Expand Down
Loading

0 comments on commit a919fb1

Please sign in to comment.