Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Support inline asm with 128-bit operand in NVPTX backend #97113

Merged

Conversation

Chengjunp
Copy link
Contributor

This change supports the 128-bit operands for inline ptx asm, both input and output.

The major changes are:

  • Tablegen:
        Define Int128Regs in NVPTXRegisterInfo.td. But this register does not set as general register type in NVPTX backend so that this change will not influence the codegen without inline asm.
        Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and I128toV2I64. The first one moves a register, the second one moves two 64-bit registers into one 128-bit register, and the third one just does the opposite.
  • NVPTXISelLowering & NVPTXISelDAGToDAG:
        Custom lowering CopyToReg and CopyFromReg with 128-bit operands. CopyToReg deals with the inputs of the inline asm and the CopyFromReg deals with the outputs.
        CopyToReg is custom lowered into a V2I64toI128, which takes in the expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.
        CopyFromReg is custom lowered by adding a I128toV2I64, which breaks down the 128-bit outputs of inline asm into the expanded values.

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" backend:NVPTX labels Jun 28, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-nvptx

Author: None (Chengjunp)

Changes

This change supports the 128-bit operands for inline ptx asm, both input and output.

The major changes are:

  • Tablegen:
        Define Int128Regs in NVPTXRegisterInfo.td. But this register does not set as general register type in NVPTX backend so that this change will not influence the codegen without inline asm.
        Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and I128toV2I64. The first one moves a register, the second one moves two 64-bit registers into one 128-bit register, and the third one just does the opposite.
  • NVPTXISelLowering & NVPTXISelDAGToDAG:
        Custom lowering CopyToReg and CopyFromReg with 128-bit operands. CopyToReg deals with the inputs of the inline asm and the CopyFromReg deals with the outputs.
        CopyToReg is custom lowered into a V2I64toI128, which takes in the expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.
        CopyFromReg is custom lowered by adding a I128toV2I64, which breaks down the 128-bit outputs of inline asm into the expanded values.

Patch is 31.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97113.diff

14 Files Affected:

  • (modified) clang/lib/Basic/Targets/NVPTX.h (+1)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+68)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+2-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+88)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+10-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+3)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll (+148)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test2.ll (+122)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test3.ll (+67)
diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h
index f476d49047c01..7e9b6b34df636 100644
--- a/clang/lib/Basic/Targets/NVPTX.h
+++ b/clang/lib/Basic/Targets/NVPTX.h
@@ -105,6 +105,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
     case 'l':
     case 'f':
     case 'd':
+    case 'q':
       Info.setAllowsRegister();
       return true;
     }
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b7a20c351f5ff..380d878c1f532 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -60,6 +60,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
   case 6:
     OS << "%fd";
     break;
+  case 7:
+    OS << "%rq";
+    break;
   }
 
   unsigned VReg = Reg.id() & 0x0FFFFFFF;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index ca077d41d36ba..1645261d74d06 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -315,6 +315,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
       Ret = (5 << 28);
     } else if (RC == &NVPTX::Float64RegsRegClass) {
       Ret = (6 << 28);
+    } else if (RC == &NVPTX::Int128RegsRegClass) {
+      Ret = (7 << 28);
     } else {
       report_fatal_error("Bad register class");
     }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 1e1cbb15e33d4..11193c11ede3b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -519,6 +519,20 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryConstantFP(N))
       return;
     break;
+  case ISD::CopyToReg: {
+    if (N->getOperand(1).getValueType() == MVT::i128) {
+      SelectV2I64toI128(N);
+      return;
+    }
+    break;
+  }
+  case ISD::CopyFromReg: {
+    if (N->getOperand(1).getValueType() == MVT::i128) {
+      SelectI128toV2I64(N);
+      return;
+    }
+    break;
+  }
   default:
     break;
   }
@@ -3798,6 +3812,60 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
   return true;
 }
 
+void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
+  // Lower a CopyToReg with two 64-bit inputs
+  // Dst:i128, lo:i64, hi:i64
+  //
+  // CopyToReg Dst, lo, hi;
+  //
+  // ==>
+  //
+  // tmp = V2I64toI128 {lo, hi};
+  // CopyToReg Dst, tmp;
+  SDValue Dst = N->getOperand(1);
+  SDValue Lo = N->getOperand(2);
+  SDValue Hi = N->getOperand(3);
+
+  SDLoc DL(N);
+  SDNode *Mov =
+      CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});
+
+  SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
+  NewOps[0] = N->getOperand(0);
+  NewOps[1] = Dst;
+  NewOps[2] = SDValue(Mov, 0);
+  if (N->getNumOperands() == 5)
+    NewOps[3] = N->getOperand(4);
+  SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL, SmallVector<EVT>(N->values()), NewOps);
+
+  ReplaceNode(N, NewValue.getNode());
+}
+
+void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
+  // Lower CopyFromReg from a 128-bit regs to two 64-bit regs
+  // Dst:i128, Src:i128
+  //
+  // {lo, hi} = CopyFromReg Src
+  //
+  // ==>
+  //
+  // {lo, hi} = I128toV2I64 Src
+  //
+  SDValue Ch = N->getOperand(0);
+  SDValue Src = N->getOperand(1);
+  SDValue Glue = N->getOperand(2);
+  SDLoc DL(N);
+
+  // Add Glue and Ch to the operands and results to avoid break the execution
+  // order
+  SDNode *Mov = CurDAG->getMachineNode(
+      NVPTX::I128toV2I64, DL,
+      {MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()},
+      {Src, Ch, Glue});
+
+  ReplaceNode(N, Mov);
+}
+
 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
 /// conversion from \p SrcTy to \p DestTy.
 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c5524351f2ff9..49626d4051485 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,7 +74,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool SelectSETP_F16X2(SDNode *N);
   bool SelectSETP_BF16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
-
+  void SelectV2I64toI128(SDNode *N);
+  void SelectI128toV2I64(SDNode *N);
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 76633a437fe71..c02d874a9a6b3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -859,6 +859,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
 
+  // Custom lowering for inline asm with 128-bit operands
+  setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
+  setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);
+
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.
 
@@ -2804,6 +2808,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVectorArith(Op, DAG);
   case ISD::DYNAMIC_STACKALLOC:
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
+  case ISD::CopyToReg:
+    return LowerCopyToReg_128(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
@@ -3094,6 +3100,54 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
   return Result;
 }
 
+SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
+  // operand so that it can pass the legalization.
+
+  assert(Op.getOperand(1).getValueType() == MVT::i128 &&
+         "Custom lowering for 128-bit CopyToReg only");
+
+  SDNode *Node = Op.getNode();
+  SDLoc DL(Node);
+
+  SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
+  SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
+                           DAG.getIntPtrConstant(0, DL));
+  SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
+                           DAG.getIntPtrConstant(1, DL));
+
+  SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
+  SmallVector<EVT, 3> ResultsType(Node->values());
+
+  NewOps[0] = Op->getOperand(0); // Chain
+  NewOps[1] = Op->getOperand(1); // Dst Reg
+  NewOps[2] = Lo;                // Lower 64-bit
+  NewOps[3] = Hi;                // Higher 64-bit
+  if (Op.getNumOperands() == 4)
+    NewOps[4] = Op->getOperand(3); // Glue if exists
+
+  return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
+}
+
+unsigned NVPTXTargetLowering::getNumRegisters(
+    LLVMContext &Context, EVT VT,
+    std::optional<MVT> RegisterVT = std::nullopt) const {
+  if (VT == MVT::i128 && RegisterVT == MVT::i128)
+    return 1;
+  return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
+}
+
+bool NVPTXTargetLowering::splitValueIntoRegisterParts(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
+    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
+  if (Val.getValueType() == MVT::i128 && NumParts == 1) {
+    Parts[0] = Val;
+    return true;
+  }
+  return false;
+}
+
 // This creates target external symbol for a function parameter.
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
@@ -5152,6 +5206,7 @@ NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
     case 'l':
     case 'f':
     case 'd':
+    case 'q':
     case '0':
     case 'N':
       return C_RegisterClass;
@@ -5177,6 +5232,12 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     case 'l':
     case 'N':
       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
+    case 'q': {
+      if (STI.getSmVersion() < 70)
+        report_fatal_error("Inline asm with 128 bit operands is only "
+                           "supported for sm_70 and higher!");
+      return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
+    }
     case 'f':
       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
     case 'd':
@@ -6244,6 +6305,30 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
   }
 }
 
+static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
+                                   SmallVectorImpl<SDValue> &Results) {
+  // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
+  // result so that it can pass the legalization
+  SDLoc DL(N);
+  SDValue Chain = N->getOperand(0);
+  SDValue Reg = N->getOperand(1);
+  SDValue Glue = N->getOperand(2);
+
+  assert(Reg.getValueType() == MVT::i128 &&
+         "Custom lowering for CopyFromReg with 128-bit reg only");
+  SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
+                                     N->getValueType(2)};
+  SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
+
+  SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
+  SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
+                             {NewValue.getValue(0), NewValue.getValue(1)});
+
+  Results.push_back(Pair);
+  Results.push_back(NewValue.getValue(2));
+  Results.push_back(NewValue.getValue(3));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -6255,6 +6340,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::INTRINSIC_W_CHAIN:
     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
     return;
+  case ISD::CopyFromReg:
+    ReplaceCopyFromReg_128(N, DAG, Results);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index e211286fcc556..63262961b363e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -640,6 +640,14 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
+  unsigned getNumRegisters(LLVMContext &Context, EVT VT,
+                           std::optional<MVT> RegisterVT) const override;
+  bool
+  splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
+                              SDValue *Parts, unsigned NumParts, MVT PartVT,
+                              std::optional<CallingConv::ID> CC) const override;
+
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
   SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index b0d792b5ee3fe..673858f92e7ce 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -51,6 +51,8 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   } else if (DestRC == &NVPTX::Int64RegsRegClass) {
     Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
                                              : NVPTX::BITCONVERT_64_F2I);
+  } else if (DestRC == &NVPTX::Int128RegsRegClass) {
+    Op = NVPTX::IMOV128rr;
   } else if (DestRC == &NVPTX::Float32RegsRegClass) {
     Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
                                                : NVPTX::BITCONVERT_32_I2F);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index c4c35a1f74ba9..827febe845a4c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2097,6 +2097,8 @@ let IsSimpleMove=1, hasSideEffects=0 in {
                            "mov.u32 \t$dst, $sss;", []>;
   def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss),
                            "mov.u64 \t$dst, $sss;", []>;
+  def IMOV128rr : NVPTXInst<(outs Int128Regs:$dst), (ins Int128Regs:$sss),
+                           "mov.b128 \t$dst, $sss;", []>;
 
   def IMOVB16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$sss),
                            "mov.b16 \t$dst, $sss;", []>;
@@ -3545,6 +3547,9 @@ let hasSideEffects = false in {
   def V2I32toI64 : NVPTXInst<(outs Int64Regs:$d),
                              (ins Int32Regs:$s1, Int32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
+  def V2I64toI128 : NVPTXInst<(outs Int128Regs:$d),
+                              (ins Int64Regs:$s1, Int64Regs:$s2),
+                              "mov.b128 \t$d, {{$s1, $s2}};", []>;
   def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
                              (ins Float32Regs:$s1, Float32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
@@ -3560,6 +3565,9 @@ let hasSideEffects = false in {
   def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
                              (ins Int64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $s;", []>;
+  def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
+                              (ins Int128Regs:$s),
+                              "mov.b128 \t{{$d1, $d2}}, $s;", []>;
   def F64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
                              (ins Float64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $s;", []>;
@@ -3629,7 +3637,7 @@ def : Pat<(i32 (ctlz (i32 Int32Regs:$a))), (CLZr32 Int32Regs:$a)>;
 // ptx value to 64 bits to match the ISD node's semantics, unless we know we're
 // truncating back down to 32 bits.
 def : Pat<(i64 (ctlz Int64Regs:$a)), (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>;
-def : Pat<(i32 (trunc (ctlz Int64Regs:$a))), (CLZr64 Int64Regs:$a)>;
+def : Pat<(i32 (trunc (i64 (ctlz Int64Regs:$a)))), (CLZr64 Int64Regs:$a)>;
 
 // For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
 // result back to 16-bits if necessary.  We also need to subtract 16 because
@@ -3667,7 +3675,7 @@ def : Pat<(i32 (ctpop (i32 Int32Regs:$a))), (POPCr32 Int32Regs:$a)>;
 // pattern that avoids the type conversion if we're truncating the result to
 // i32 anyway.
 def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>;
-def : Pat<(i32 (trunc (ctpop Int64Regs:$a))), (POPCr64 Int64Regs:$a)>;
+def : Pat<(i32 (trunc (i64 (ctpop Int64Regs:$a)))), (POPCr64 Int64Regs:$a)>;
 
 // For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
 // If we know that we're storing into an i32, we can avoid the final trunc.
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index f1213f030bba7..a8a23f04c1249 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -31,6 +31,8 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
     return ".f32";
   if (RC == &NVPTX::Float64RegsRegClass)
     return ".f64";
+  if (RC == &NVPTX::Int128RegsRegClass)
+    return ".b128";
   if (RC == &NVPTX::Int64RegsRegClass)
     // We use untyped (.b) integer registers here as NVCC does.
     // Correctness of generated code does not depend on register type,
@@ -67,6 +69,8 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
     return "%f";
   if (RC == &NVPTX::Float64RegsRegClass)
     return "%fd";
+  if (RC == &NVPTX::Int128RegsRegClass)
+    return "%rq";
   if (RC == &NVPTX::Int64RegsRegClass)
     return "%rd";
   if (RC == &NVPTX::Int32RegsRegClass)
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index b5231a9cf67f9..2011f0f7e328f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -37,6 +37,7 @@ foreach i = 0...4 in {
   def RS#i : NVPTXReg<"%rs"#i>; // 16-bit
   def R#i  : NVPTXReg<"%r"#i>;  // 32-bit
   def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
+  def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
   def H#i  : NVPTXReg<"%h"#i>;  // 16-bit float
   def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
   def F#i  : NVPTXReg<"%f"#i>;  // 32-bit float
@@ -62,6 +63,8 @@ def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
                               (add (sequence "R%u", 0, 4),
                               VRFrame32, VRFrameLocal32)>;
 def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
+// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
+def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
 def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
 def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
 def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
diff --git a/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
new file mode 100644
index 0000000000000..3232f40a40a70
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
@@ -0,0 +1,148 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --version 5
+; RUN: llc < %s -march=nvptx -mcpu=sm_70 -mattr=+ptx83 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_70 -mattr=+ptx83 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@value = internal addrspace(1) global i128 0, align 16
+
+define void @test_b128_input_from_const() {
+; CHECK-LABEL: test_b128_input_from_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.u64 %rd2, 0;
+; CHECK-NEXT:    mov.u64 %rd3, 42;
+; CHECK-NEXT:    mov.b128 %rq1, {%rd3, %rd2};
+; CHECK-NEXT:    mov.u32 %r1, value;
+; CHECK-NEXT:    cvta.global.u32 %r2, %r1;
+; CHECK-NEXT:    cvt.u64.u32 %rd1, %r2;
+; CHECK-NEXT:    // begin inline asm
+; CHECK-NEXT:    { st.b128 [%rd1], %rq1; }
+; CHECK-NEXT:    // end inline asm
+; CHECK-NEXT:    ret;
+
+  tail call void asm sideeffect "{ st.b128 [$0], $1; }", "l,q"(ptr nonnull addrspacecast (ptr addrspace(1) @value to ptr), i128 42)
+  ret void
+}
+
+define void @test_b128_input_from_load(ptr nocapture readonly %data) {
+; CHECK-LABEL: test_b128_input_from_load(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_b128_input_from_load_param_0];
+; CHECK-NEXT:    cvta.to.global.u32 %r2, %r1;
+; CHECK-NEXT:    ld.global.u64 %rd2, [%r2+8];
+; CHECK-NEXT:    ld.global.u64 %rd3, [%r2];
+; CHECK-NEXT:    mov.b128 %rq1, {%rd3, %rd2};
+; CHECK-NEXT:    mov.u32 %r3, value;
+; CHECK-NEXT:    cvta.global.u32 %r4, %r3;
+; CHECK-NEXT:    cvt.u64.u32 %rd1, %r4;
+; CHECK-NEXT:    // begin inline asm
+; CHECK-NEXT:    { st.b128 [%rd1], %rq1; }
+; CHECK-NEXT:    // end inline asm
+; CHECK-NEXT:    ret;
+
+  %1 = addrspacecast ptr %data to ptr addrspace(1)
+  %2 = load <2 x i64>, ptr addrspace(1) %1, align 16
+  %3 = bitcast <2 x i64> %2 to i128
+  tail call void asm sideeffect "{ st.b128 [$0], $1; }", "l,q"(ptr nonnull addrspacecast (ptr addrspace(1) @value to ptr), i128 %3)
+  ret void
+}
+
+define void @test_b128_input_from_select(ptr nocapture readonly %flag) {
+; CHECK-LABEL: test_b128_input_from_select(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_b128_input_from_select_param_0];
+; CHECK-NEXT:    cvta.to.glo...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-clang

Author: None (Chengjunp)

Changes

This change supports the 128-bit operands for inline ptx asm, both input and output.

The major changes are:

  • Tablegen:
        Define Int128Regs in NVPTXRegisterInfo.td. But this register does not set as general register type in NVPTX backend so that this change will not influence the codegen without inline asm.
        Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and I128toV2I64. The first one moves a register, the second one moves two 64-bit registers into one 128-bit register, and the third one just does the opposite.
  • NVPTXISelLowering & NVPTXISelDAGToDAG:
        Custom lowering CopyToReg and CopyFromReg with 128-bit operands. CopyToReg deals with the inputs of the inline asm and the CopyFromReg deals with the outputs.
        CopyToReg is custom lowered into a V2I64toI128, which takes in the expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.
        CopyFromReg is custom lowered by adding a I128toV2I64, which breaks down the 128-bit outputs of inline asm into the expanded values.

Patch is 31.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97113.diff

14 Files Affected:

  • (modified) clang/lib/Basic/Targets/NVPTX.h (+1)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+68)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+2-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+88)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+10-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+3)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll (+148)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test2.ll (+122)
  • (added) llvm/test/CodeGen/NVPTX/inline-asm-b128-test3.ll (+67)
diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h
index f476d49047c01..7e9b6b34df636 100644
--- a/clang/lib/Basic/Targets/NVPTX.h
+++ b/clang/lib/Basic/Targets/NVPTX.h
@@ -105,6 +105,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXTargetInfo : public TargetInfo {
     case 'l':
     case 'f':
     case 'd':
+    case 'q':
       Info.setAllowsRegister();
       return true;
     }
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b7a20c351f5ff..380d878c1f532 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -60,6 +60,9 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
   case 6:
     OS << "%fd";
     break;
+  case 7:
+    OS << "%rq";
+    break;
   }
 
   unsigned VReg = Reg.id() & 0x0FFFFFFF;
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index ca077d41d36ba..1645261d74d06 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -315,6 +315,8 @@ unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
       Ret = (5 << 28);
     } else if (RC == &NVPTX::Float64RegsRegClass) {
       Ret = (6 << 28);
+    } else if (RC == &NVPTX::Int128RegsRegClass) {
+      Ret = (7 << 28);
     } else {
       report_fatal_error("Bad register class");
     }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 1e1cbb15e33d4..11193c11ede3b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -519,6 +519,20 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryConstantFP(N))
       return;
     break;
+  case ISD::CopyToReg: {
+    if (N->getOperand(1).getValueType() == MVT::i128) {
+      SelectV2I64toI128(N);
+      return;
+    }
+    break;
+  }
+  case ISD::CopyFromReg: {
+    if (N->getOperand(1).getValueType() == MVT::i128) {
+      SelectI128toV2I64(N);
+      return;
+    }
+    break;
+  }
   default:
     break;
   }
@@ -3798,6 +3812,60 @@ bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
   return true;
 }
 
+void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
+  // Lower a CopyToReg with two 64-bit inputs
+  // Dst:i128, lo:i64, hi:i64
+  //
+  // CopyToReg Dst, lo, hi;
+  //
+  // ==>
+  //
+  // tmp = V2I64toI128 {lo, hi};
+  // CopyToReg Dst, tmp;
+  SDValue Dst = N->getOperand(1);
+  SDValue Lo = N->getOperand(2);
+  SDValue Hi = N->getOperand(3);
+
+  SDLoc DL(N);
+  SDNode *Mov =
+      CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});
+
+  SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
+  NewOps[0] = N->getOperand(0);
+  NewOps[1] = Dst;
+  NewOps[2] = SDValue(Mov, 0);
+  if (N->getNumOperands() == 5)
+    NewOps[3] = N->getOperand(4);
+  SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL, SmallVector<EVT>(N->values()), NewOps);
+
+  ReplaceNode(N, NewValue.getNode());
+}
+
+void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
+  // Lower CopyFromReg from a 128-bit regs to two 64-bit regs
+  // Dst:i128, Src:i128
+  //
+  // {lo, hi} = CopyFromReg Src
+  //
+  // ==>
+  //
+  // {lo, hi} = I128toV2I64 Src
+  //
+  SDValue Ch = N->getOperand(0);
+  SDValue Src = N->getOperand(1);
+  SDValue Glue = N->getOperand(2);
+  SDLoc DL(N);
+
+  // Add Glue and Ch to the operands and results to avoid break the execution
+  // order
+  SDNode *Mov = CurDAG->getMachineNode(
+      NVPTX::I128toV2I64, DL,
+      {MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()},
+      {Src, Ch, Glue});
+
+  ReplaceNode(N, Mov);
+}
+
 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
 /// conversion from \p SrcTy to \p DestTy.
 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c5524351f2ff9..49626d4051485 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,7 +74,8 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool SelectSETP_F16X2(SDNode *N);
   bool SelectSETP_BF16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
-
+  void SelectV2I64toI128(SDNode *N);
+  void SelectI128toV2I64(SDNode *N);
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 76633a437fe71..c02d874a9a6b3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -859,6 +859,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
 
+  // Custom lowering for inline asm with 128-bit operands
+  setOperationAction(ISD::CopyToReg, MVT::i128, Custom);
+  setOperationAction(ISD::CopyFromReg, MVT::i128, Custom);
+
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.
 
@@ -2804,6 +2808,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVectorArith(Op, DAG);
   case ISD::DYNAMIC_STACKALLOC:
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
+  case ISD::CopyToReg:
+    return LowerCopyToReg_128(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
@@ -3094,6 +3100,54 @@ SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
   return Result;
 }
 
+SDValue NVPTXTargetLowering::LowerCopyToReg_128(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  // Change the CopyToReg to take in two 64-bit operands instead of a 128-bit
+  // operand so that it can pass the legalization.
+
+  assert(Op.getOperand(1).getValueType() == MVT::i128 &&
+         "Custom lowering for 128-bit CopyToReg only");
+
+  SDNode *Node = Op.getNode();
+  SDLoc DL(Node);
+
+  SDValue Cast = DAG.getBitcast(MVT::v2i64, Op->getOperand(2));
+  SDValue Lo = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
+                           DAG.getIntPtrConstant(0, DL));
+  SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64, Cast,
+                           DAG.getIntPtrConstant(1, DL));
+
+  SmallVector<SDValue, 5> NewOps(Op->getNumOperands() + 1);
+  SmallVector<EVT, 3> ResultsType(Node->values());
+
+  NewOps[0] = Op->getOperand(0); // Chain
+  NewOps[1] = Op->getOperand(1); // Dst Reg
+  NewOps[2] = Lo;                // Lower 64-bit
+  NewOps[3] = Hi;                // Higher 64-bit
+  if (Op.getNumOperands() == 4)
+    NewOps[4] = Op->getOperand(3); // Glue if exists
+
+  return DAG.getNode(ISD::CopyToReg, DL, ResultsType, NewOps);
+}
+
+unsigned NVPTXTargetLowering::getNumRegisters(
+    LLVMContext &Context, EVT VT,
+    std::optional<MVT> RegisterVT = std::nullopt) const {
+  if (VT == MVT::i128 && RegisterVT == MVT::i128)
+    return 1;
+  return TargetLoweringBase::getNumRegisters(Context, VT, RegisterVT);
+}
+
+bool NVPTXTargetLowering::splitValueIntoRegisterParts(
+    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
+    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
+  if (Val.getValueType() == MVT::i128 && NumParts == 1) {
+    Parts[0] = Val;
+    return true;
+  }
+  return false;
+}
+
 // This creates target external symbol for a function parameter.
 // Name of the symbol is composed from its index and the function name.
 // Negative index corresponds to special parameter (unsized array) used for
@@ -5152,6 +5206,7 @@ NVPTXTargetLowering::getConstraintType(StringRef Constraint) const {
     case 'l':
     case 'f':
     case 'd':
+    case 'q':
     case '0':
     case 'N':
       return C_RegisterClass;
@@ -5177,6 +5232,12 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     case 'l':
     case 'N':
       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
+    case 'q': {
+      if (STI.getSmVersion() < 70)
+        report_fatal_error("Inline asm with 128 bit operands is only "
+                           "supported for sm_70 and higher!");
+      return std::make_pair(0U, &NVPTX::Int128RegsRegClass);
+    }
     case 'f':
       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
     case 'd':
@@ -6244,6 +6305,30 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
   }
 }
 
+static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
+                                   SmallVectorImpl<SDValue> &Results) {
+  // Change the CopyFromReg to output 2 64-bit results instead of a 128-bit
+  // result so that it can pass the legalization
+  SDLoc DL(N);
+  SDValue Chain = N->getOperand(0);
+  SDValue Reg = N->getOperand(1);
+  SDValue Glue = N->getOperand(2);
+
+  assert(Reg.getValueType() == MVT::i128 &&
+         "Custom lowering for CopyFromReg with 128-bit reg only");
+  SmallVector<EVT, 4> ResultsType = {MVT::i64, MVT::i64, N->getValueType(1),
+                                     N->getValueType(2)};
+  SmallVector<SDValue, 3> NewOps = {Chain, Reg, Glue};
+
+  SDValue NewValue = DAG.getNode(ISD::CopyFromReg, DL, ResultsType, NewOps);
+  SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
+                             {NewValue.getValue(0), NewValue.getValue(1)});
+
+  Results.push_back(Pair);
+  Results.push_back(NewValue.getValue(2));
+  Results.push_back(NewValue.getValue(3));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -6255,6 +6340,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::INTRINSIC_W_CHAIN:
     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
     return;
+  case ISD::CopyFromReg:
+    ReplaceCopyFromReg_128(N, DAG, Results);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index e211286fcc556..63262961b363e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -640,6 +640,14 @@ class NVPTXTargetLowering : public TargetLowering {
   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
+  unsigned getNumRegisters(LLVMContext &Context, EVT VT,
+                           std::optional<MVT> RegisterVT) const override;
+  bool
+  splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
+                              SDValue *Parts, unsigned NumParts, MVT PartVT,
+                              std::optional<CallingConv::ID> CC) const override;
+
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
   SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index b0d792b5ee3fe..673858f92e7ce 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -51,6 +51,8 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
   } else if (DestRC == &NVPTX::Int64RegsRegClass) {
     Op = (SrcRC == &NVPTX::Int64RegsRegClass ? NVPTX::IMOV64rr
                                              : NVPTX::BITCONVERT_64_F2I);
+  } else if (DestRC == &NVPTX::Int128RegsRegClass) {
+    Op = NVPTX::IMOV128rr;
   } else if (DestRC == &NVPTX::Float32RegsRegClass) {
     Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
                                                : NVPTX::BITCONVERT_32_I2F);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index c4c35a1f74ba9..827febe845a4c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -2097,6 +2097,8 @@ let IsSimpleMove=1, hasSideEffects=0 in {
                            "mov.u32 \t$dst, $sss;", []>;
   def IMOV64rr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$sss),
                            "mov.u64 \t$dst, $sss;", []>;
+  def IMOV128rr : NVPTXInst<(outs Int128Regs:$dst), (ins Int128Regs:$sss),
+                           "mov.b128 \t$dst, $sss;", []>;
 
   def IMOVB16rr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$sss),
                            "mov.b16 \t$dst, $sss;", []>;
@@ -3545,6 +3547,9 @@ let hasSideEffects = false in {
   def V2I32toI64 : NVPTXInst<(outs Int64Regs:$d),
                              (ins Int32Regs:$s1, Int32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
+  def V2I64toI128 : NVPTXInst<(outs Int128Regs:$d),
+                              (ins Int64Regs:$s1, Int64Regs:$s2),
+                              "mov.b128 \t$d, {{$s1, $s2}};", []>;
   def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
                              (ins Float32Regs:$s1, Float32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
@@ -3560,6 +3565,9 @@ let hasSideEffects = false in {
   def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
                              (ins Int64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $s;", []>;
+  def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
+                              (ins Int128Regs:$s),
+                              "mov.b128 \t{{$d1, $d2}}, $s;", []>;
   def F64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
                              (ins Float64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $s;", []>;
@@ -3629,7 +3637,7 @@ def : Pat<(i32 (ctlz (i32 Int32Regs:$a))), (CLZr32 Int32Regs:$a)>;
 // ptx value to 64 bits to match the ISD node's semantics, unless we know we're
 // truncating back down to 32 bits.
 def : Pat<(i64 (ctlz Int64Regs:$a)), (CVT_u64_u32 (CLZr64 Int64Regs:$a), CvtNONE)>;
-def : Pat<(i32 (trunc (ctlz Int64Regs:$a))), (CLZr64 Int64Regs:$a)>;
+def : Pat<(i32 (trunc (i64 (ctlz Int64Regs:$a)))), (CLZr64 Int64Regs:$a)>;
 
 // For 16-bit ctlz, we zero-extend to 32-bit, perform the count, then trunc the
 // result back to 16-bits if necessary.  We also need to subtract 16 because
@@ -3667,7 +3675,7 @@ def : Pat<(i32 (ctpop (i32 Int32Regs:$a))), (POPCr32 Int32Regs:$a)>;
 // pattern that avoids the type conversion if we're truncating the result to
 // i32 anyway.
 def : Pat<(ctpop Int64Regs:$a), (CVT_u64_u32 (POPCr64 Int64Regs:$a), CvtNONE)>;
-def : Pat<(i32 (trunc (ctpop Int64Regs:$a))), (POPCr64 Int64Regs:$a)>;
+def : Pat<(i32 (trunc (i64 (ctpop Int64Regs:$a)))), (POPCr64 Int64Regs:$a)>;
 
 // For 16-bit, we zero-extend to 32-bit, then trunc the result back to 16-bits.
 // If we know that we're storing into an i32, we can avoid the final trunc.
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
index f1213f030bba7..a8a23f04c1249 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -31,6 +31,8 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
     return ".f32";
   if (RC == &NVPTX::Float64RegsRegClass)
     return ".f64";
+  if (RC == &NVPTX::Int128RegsRegClass)
+    return ".b128";
   if (RC == &NVPTX::Int64RegsRegClass)
     // We use untyped (.b) integer registers here as NVCC does.
     // Correctness of generated code does not depend on register type,
@@ -67,6 +69,8 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
     return "%f";
   if (RC == &NVPTX::Float64RegsRegClass)
     return "%fd";
+  if (RC == &NVPTX::Int128RegsRegClass)
+    return "%rq";
   if (RC == &NVPTX::Int64RegsRegClass)
     return "%rd";
   if (RC == &NVPTX::Int32RegsRegClass)
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
index b5231a9cf67f9..2011f0f7e328f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -37,6 +37,7 @@ foreach i = 0...4 in {
   def RS#i : NVPTXReg<"%rs"#i>; // 16-bit
   def R#i  : NVPTXReg<"%r"#i>;  // 32-bit
   def RL#i : NVPTXReg<"%rd"#i>; // 64-bit
+  def RQ#i : NVPTXReg<"%rq"#i>; // 128-bit
   def H#i  : NVPTXReg<"%h"#i>;  // 16-bit float
   def HH#i : NVPTXReg<"%hh"#i>; // 2x16-bit float
   def F#i  : NVPTXReg<"%f"#i>;  // 32-bit float
@@ -62,6 +63,8 @@ def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
                               (add (sequence "R%u", 0, 4),
                               VRFrame32, VRFrameLocal32)>;
 def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
+// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
+def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
 def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
 def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
 def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
diff --git a/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
new file mode 100644
index 0000000000000..3232f40a40a70
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll
@@ -0,0 +1,148 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --version 5
+; RUN: llc < %s -march=nvptx -mcpu=sm_70 -mattr=+ptx83 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_70 -mattr=+ptx83 | %ptxas-verify -arch=sm_70 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@value = internal addrspace(1) global i128 0, align 16
+
+define void @test_b128_input_from_const() {
+; CHECK-LABEL: test_b128_input_from_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.u64 %rd2, 0;
+; CHECK-NEXT:    mov.u64 %rd3, 42;
+; CHECK-NEXT:    mov.b128 %rq1, {%rd3, %rd2};
+; CHECK-NEXT:    mov.u32 %r1, value;
+; CHECK-NEXT:    cvta.global.u32 %r2, %r1;
+; CHECK-NEXT:    cvt.u64.u32 %rd1, %r2;
+; CHECK-NEXT:    // begin inline asm
+; CHECK-NEXT:    { st.b128 [%rd1], %rq1; }
+; CHECK-NEXT:    // end inline asm
+; CHECK-NEXT:    ret;
+
+  tail call void asm sideeffect "{ st.b128 [$0], $1; }", "l,q"(ptr nonnull addrspacecast (ptr addrspace(1) @value to ptr), i128 42)
+  ret void
+}
+
+define void @test_b128_input_from_load(ptr nocapture readonly %data) {
+; CHECK-LABEL: test_b128_input_from_load(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_b128_input_from_load_param_0];
+; CHECK-NEXT:    cvta.to.global.u32 %r2, %r1;
+; CHECK-NEXT:    ld.global.u64 %rd2, [%r2+8];
+; CHECK-NEXT:    ld.global.u64 %rd3, [%r2];
+; CHECK-NEXT:    mov.b128 %rq1, {%rd3, %rd2};
+; CHECK-NEXT:    mov.u32 %r3, value;
+; CHECK-NEXT:    cvta.global.u32 %r4, %r3;
+; CHECK-NEXT:    cvt.u64.u32 %rd1, %r4;
+; CHECK-NEXT:    // begin inline asm
+; CHECK-NEXT:    { st.b128 [%rd1], %rq1; }
+; CHECK-NEXT:    // end inline asm
+; CHECK-NEXT:    ret;
+
+  %1 = addrspacecast ptr %data to ptr addrspace(1)
+  %2 = load <2 x i64>, ptr addrspace(1) %1, align 16
+  %3 = bitcast <2 x i64> %2 to i128
+  tail call void asm sideeffect "{ st.b128 [$0], $1; }", "l,q"(ptr nonnull addrspacecast (ptr addrspace(1) @value to ptr), i128 %3)
+  ret void
+}
+
+define void @test_b128_input_from_select(ptr nocapture readonly %flag) {
+; CHECK-LABEL: test_b128_input_from_select(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-NEXT:    .reg .b128 %rq<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_b128_input_from_select_param_0];
+; CHECK-NEXT:    cvta.to.glo...
[truncated]

@AlexMaclean AlexMaclean requested a review from jlebar June 28, 2024 21:16
Copy link
Member

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be documented in the langref in this section, right? https://llvm.org/docs/LangRef.html#supported-constraint-code-list

Copy link
Member

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM other than the previous comment.

@Chengjunp
Copy link
Contributor Author

This needs to be documented in the langref in this section, right? https://llvm.org/docs/LangRef.html#supported-constraint-code-list
Oh, you are right. Thank you for pointing this out. Which file should I modify? Is it llvm/docs/LangRef.rst? Thanks!

@jlebar
Copy link
Member

jlebar commented Jun 28, 2024

Which file should I modify?

Use git grep to find where the text from that section of the langref lives?

@Chengjunp Chengjunp force-pushed the dev/chengjunp/inline_asm_b128_upstream branch from fe55642 to a047ab2 Compare July 1, 2024 21:31
@AlexMaclean AlexMaclean merged commit cbd3f25 into llvm:main Jul 1, 2024
4 of 5 checks passed
Copy link

github-actions bot commented Jul 1, 2024

@Chengjunp Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested
by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as
the builds can include changes from many authors. It is not uncommon for your
change to be included in a build that fails due to someone else's changes, or
infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself.
This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…m#97113)

This change supports the 128-bit operands for inline ptx asm, both input
and output.\
\
The major changes are:

- Tablegen:\
    Define Int128Regs in NVPTXRegisterInfo.td. But this register does
not set as general register type in NVPTX backend so that this change
will not influence the codegen without inline asm.\
    Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and
I128toV2I64. The first one moves a register, the second one moves two
64-bit registers into one 128-bit register, and the third one just does
the opposite.
- NVPTXISelLowering & NVPTXISelDAGToDAG:\
    Custom lowering CopyToReg and CopyFromReg with 128-bit operands.
CopyToReg deals with the inputs of the inline asm and the CopyFromReg
deals with the outputs.\
    CopyToReg is custom lowered into a V2I64toI128, which takes in the
expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.\
    CopyFromReg is custom lowered by adding a I128toV2I64, which breaks
down the 128-bit outputs of inline asm into the expanded values.
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
…m#97113)

This change supports the 128-bit operands for inline ptx asm, both input
and output.\
\
The major changes are:

- Tablegen:\
    Define Int128Regs in NVPTXRegisterInfo.td. But this register does
not set as general register type in NVPTX backend so that this change
will not influence the codegen without inline asm.\
    Define three NVPTX intrinsics, IMOV128rr, V2I64toI128 and
I128toV2I64. The first one moves a register, the second one moves two
64-bit registers into one 128-bit register, and the third one just does
the opposite.
- NVPTXISelLowering & NVPTXISelDAGToDAG:\
    Custom lowering CopyToReg and CopyFromReg with 128-bit operands.
CopyToReg deals with the inputs of the inline asm and the CopyFromReg
deals with the outputs.\
    CopyToReg is custom lowered into a V2I64toI128, which takes in the
expanded values(Lo and Hi) of the input, and moves into a 128-bit reg.\
    CopyFromReg is custom lowered by adding a I128toV2I64, which breaks
down the 128-bit outputs of inline asm into the expanded values.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants