Skip to content

Commit

Permalink
[RISCV] Lower f16/bf16 splat_vector by bitcasting to i16 instead of p…
Browse files Browse the repository at this point in the history
…romoting to f32. (#108298)

If f16/bf16 scalar types are not legal we also need to custom legalize
to prevent a crash. We do similar lowering for build_vector.
  • Loading branch information
topperc authored Sep 12, 2024
1 parent 35a0fd5 commit b2e8b8f
Show file tree
Hide file tree
Showing 34 changed files with 2,099 additions and 2,454 deletions.
42 changes: 22 additions & 20 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
if (Subtarget.hasStdExtZfhmin())
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
else
setOperationAction(ISD::SPLAT_VECTOR, MVT::f16, Custom);
// load/store
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);

Expand Down Expand Up @@ -1117,6 +1119,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
if (Subtarget.hasStdExtZfbfmin())
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
else
setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom);

setOperationAction(ISD::FNEG, VT, Expand);
Expand Down Expand Up @@ -6988,30 +6992,28 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerVECTOR_SPLICE(Op, DAG);
case ISD::BUILD_VECTOR:
return lowerBUILD_VECTOR(Op, DAG, Subtarget);
case ISD::SPLAT_VECTOR:
if ((Op.getValueType().getScalarType() == MVT::f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
Subtarget.hasStdExtZfhminOrZhinxmin() &&
!Subtarget.hasVInstructionsF16())) ||
(Op.getValueType().getScalarType() == MVT::bf16 &&
(Subtarget.hasVInstructionsBF16Minimal() &&
Subtarget.hasStdExtZfbfmin()))) {
if (Op.getValueType() == MVT::nxv32f16 ||
Op.getValueType() == MVT::nxv32bf16)
return SplitVectorOp(Op, DAG);
case ISD::SPLAT_VECTOR: {
MVT VT = Op.getSimpleValueType();
MVT EltVT = VT.getVectorElementType();
if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
EltVT == MVT::bf16) {
SDLoc DL(Op);
SDValue NewScalar =
DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
SDValue NewSplat = DAG.getNode(
ISD::SPLAT_VECTOR, DL,
MVT::getVectorVT(MVT::f32, Op.getValueType().getVectorElementCount()),
NewScalar);
return DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), NewSplat,
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
SDValue Elt;
if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
(EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
Elt = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, Subtarget.getXLenVT(),
Op.getOperand(0));
else
Elt = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Op.getOperand(0));
MVT IVT = VT.changeVectorElementType(MVT::i16);
return DAG.getNode(ISD::BITCAST, DL, VT,
DAG.getNode(ISD::SPLAT_VECTOR, DL, IVT, Elt));
}
if (Op.getValueType().getVectorElementType() == MVT::i1)

if (EltVT == MVT::i1)
return lowerVectorMaskSplat(Op, DAG);
return SDValue();
}
case ISD::VECTOR_SHUFFLE:
return lowerVECTOR_SHUFFLE(Op, DAG, Subtarget);
case ISD::CONCAT_VECTORS: {
Expand Down
Loading

0 comments on commit b2e8b8f

Please sign in to comment.