Skip to content

Commit

Permalink
[PowerPC] Implement llvm.set.rounding intrinsic (#67302)
Browse files Browse the repository at this point in the history
  • Loading branch information
ecnelises authored Sep 10, 2024
1 parent ed0da00 commit 06c3311
Show file tree
Hide file tree
Showing 3 changed files with 385 additions and 5 deletions.
103 changes: 101 additions & 2 deletions llvm/lib/Target/PowerPC/PPCISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,13 @@ PPCTargetLowering::PPCTargetLowering(const PPCTargetMachine &TM,
} else {
setOperationAction(ISD::FMA , MVT::f64, Legal);
setOperationAction(ISD::FMA , MVT::f32, Legal);
setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
}

if (Subtarget.hasSPE())
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);

setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);

// If we're enabling GP optimizations, use hardware square root
if (!Subtarget.hasFSQRT() &&
!(TM.Options.UnsafeFPMath && Subtarget.hasFRSQRTE() &&
Expand Down Expand Up @@ -9060,6 +9060,103 @@ SDValue PPCTargetLowering::LowerINT_TO_FP(SDValue Op,
return FP;
}

SDValue PPCTargetLowering::LowerSET_ROUNDING(SDValue Op,
SelectionDAG &DAG) const {
SDLoc Dl(Op);
MachineFunction &MF = DAG.getMachineFunction();
EVT PtrVT = getPointerTy(MF.getDataLayout());
SDValue Chain = Op.getOperand(0);

// If requested mode is constant, just use simpler mtfsb/mffscrni
if (auto *CVal = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
uint64_t Mode = CVal->getZExtValue();
assert(Mode < 4 && "Unsupported rounding mode!");
unsigned InternalRnd = Mode ^ (~(Mode >> 1) & 1);
if (Subtarget.isISA3_0())
return SDValue(
DAG.getMachineNode(
PPC::MFFSCRNI, Dl, {MVT::f64, MVT::Other},
{DAG.getConstant(InternalRnd, Dl, MVT::i32, true), Chain}),
1);
SDNode *SetHi = DAG.getMachineNode(
(InternalRnd & 2) ? PPC::MTFSB1 : PPC::MTFSB0, Dl, MVT::Other,
{DAG.getConstant(30, Dl, MVT::i32, true), Chain});
SDNode *SetLo = DAG.getMachineNode(
(InternalRnd & 1) ? PPC::MTFSB1 : PPC::MTFSB0, Dl, MVT::Other,
{DAG.getConstant(31, Dl, MVT::i32, true), SDValue(SetHi, 0)});
return SDValue(SetLo, 0);
}

// Use x ^ (~(x >> 1) & 1) to transform LLVM rounding mode to Power format.
SDValue One = DAG.getConstant(1, Dl, MVT::i32);
SDValue SrcFlag = DAG.getNode(ISD::AND, Dl, MVT::i32, Op.getOperand(1),
DAG.getConstant(3, Dl, MVT::i32));
SDValue DstFlag = DAG.getNode(
ISD::XOR, Dl, MVT::i32, SrcFlag,
DAG.getNode(ISD::AND, Dl, MVT::i32,
DAG.getNOT(Dl,
DAG.getNode(ISD::SRL, Dl, MVT::i32, SrcFlag, One),
MVT::i32),
One));
// For Power9, there's faster mffscrn, and we don't need to read FPSCR
SDValue MFFS;
if (!Subtarget.isISA3_0()) {
MFFS = DAG.getNode(PPCISD::MFFS, Dl, {MVT::f64, MVT::Other}, Chain);
Chain = MFFS.getValue(1);
}
SDValue NewFPSCR;
if (Subtarget.isPPC64()) {
if (Subtarget.isISA3_0()) {
NewFPSCR = DAG.getAnyExtOrTrunc(DstFlag, Dl, MVT::i64);
} else {
// Set the last two bits (rounding mode) of bitcasted FPSCR.
SDNode *InsertRN = DAG.getMachineNode(
PPC::RLDIMI, Dl, MVT::i64,
{DAG.getNode(ISD::BITCAST, Dl, MVT::i64, MFFS),
DAG.getNode(ISD::ZERO_EXTEND, Dl, MVT::i64, DstFlag),
DAG.getTargetConstant(0, Dl, MVT::i32),
DAG.getTargetConstant(62, Dl, MVT::i32)});
NewFPSCR = SDValue(InsertRN, 0);
}
NewFPSCR = DAG.getNode(ISD::BITCAST, Dl, MVT::f64, NewFPSCR);
} else {
// In 32-bit mode, store f64, load and update the lower half.
int SSFI = MF.getFrameInfo().CreateStackObject(8, Align(8), false);
SDValue StackSlot = DAG.getFrameIndex(SSFI, PtrVT);
SDValue Addr = Subtarget.isLittleEndian()
? StackSlot
: DAG.getNode(ISD::ADD, Dl, PtrVT, StackSlot,
DAG.getConstant(4, Dl, PtrVT));
if (Subtarget.isISA3_0()) {
Chain = DAG.getStore(Chain, Dl, DstFlag, Addr, MachinePointerInfo());
} else {
Chain = DAG.getStore(Chain, Dl, MFFS, StackSlot, MachinePointerInfo());
SDValue Tmp =
DAG.getLoad(MVT::i32, Dl, Chain, Addr, MachinePointerInfo());
Chain = Tmp.getValue(1);
Tmp = SDValue(DAG.getMachineNode(
PPC::RLWIMI, Dl, MVT::i32,
{Tmp, DstFlag, DAG.getTargetConstant(0, Dl, MVT::i32),
DAG.getTargetConstant(30, Dl, MVT::i32),
DAG.getTargetConstant(31, Dl, MVT::i32)}),
0);
Chain = DAG.getStore(Chain, Dl, Tmp, Addr, MachinePointerInfo());
}
NewFPSCR =
DAG.getLoad(MVT::f64, Dl, Chain, StackSlot, MachinePointerInfo());
Chain = NewFPSCR.getValue(1);
}
if (Subtarget.isISA3_0())
return SDValue(DAG.getMachineNode(PPC::MFFSCRN, Dl, {MVT::f64, MVT::Other},
{NewFPSCR, Chain}),
1);
SDValue Zero = DAG.getConstant(0, Dl, MVT::i32, true);
SDNode *MTFSF = DAG.getMachineNode(
PPC::MTFSF, Dl, MVT::Other,
{DAG.getConstant(255, Dl, MVT::i32, true), NewFPSCR, Zero, Zero, Chain});
return SDValue(MTFSF, 0);
}

SDValue PPCTargetLowering::LowerGET_ROUNDING(SDValue Op,
SelectionDAG &DAG) const {
SDLoc dl(Op);
Expand Down Expand Up @@ -11921,6 +12018,8 @@ SDValue PPCTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::UINT_TO_FP:
case ISD::SINT_TO_FP: return LowerINT_TO_FP(Op, DAG);
case ISD::GET_ROUNDING: return LowerGET_ROUNDING(Op, DAG);
case ISD::SET_ROUNDING:
return LowerSET_ROUNDING(Op, DAG);

// Lower 64-bit shifts.
case ISD::SHL_PARTS: return LowerSHL_PARTS(Op, DAG);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/PowerPC/PPCISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ namespace llvm {
const SDLoc &dl) const;
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSHL_PARTS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSRL_PARTS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSRA_PARTS(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading

0 comments on commit 06c3311

Please sign in to comment.