Skip to content

Commit

Permalink
[AArch64] Extend costs for fptoi.sat intrinsics.
Browse files Browse the repository at this point in the history
Most of these bring the costs in line with the code generation. The f16 costs
without FullFP16 are usually converted to f32. Extended v2f32->v2f64 vectors
similarly use fcvtl + fcvt. As a backup we use the costs similar to the target
independent code, which should give a relatively high cost.
  • Loading branch information
davemgreen authored and banach-space committed Aug 7, 2024
1 parent 7401ef0 commit ca50690
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 121 deletions.
76 changes: 61 additions & 15 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -748,22 +748,44 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
// output are the same, or we are using cvt f64->i32 or f32->i64.
if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
LT.second == MVT::v2f64) &&
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
(LT.second == MVT::f64 && MTy == MVT::i32) ||
(LT.second == MVT::f32 && MTy == MVT::i64)))
return LT.first;
// Similarly for fp16 sizes
if (ST->hasFullFP16() &&
((LT.second == MVT::f16 && MTy == MVT::i32) ||
((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
LT.second == MVT::v2f64)) {
if ((LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
(LT.second == MVT::f64 && MTy == MVT::i32) ||
(LT.second == MVT::f32 && MTy == MVT::i64)))
return LT.first;
// Extending vector types v2f32->v2i64, fcvtl*2 + fcvt*2
if (LT.second.getScalarType() == MVT::f32 && MTy.isFixedLengthVector() &&
MTy.getScalarSizeInBits() == 64)
return LT.first * (MTy.getVectorNumElements() > 2 ? 4 : 2);
}
// Similarly for fp16 sizes. Without FullFP16 we generally need to fcvt to
// f32.
if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
return LT.first + getIntrinsicInstrCost(
{ICA.getID(),
RetTy,
{ICA.getArgTypes()[0]->getWithNewType(
Type::getFloatTy(RetTy->getContext()))}},
CostKind);
if ((LT.second == MVT::f16 && MTy == MVT::i32) ||
(LT.second == MVT::f16 && MTy == MVT::i64) ||
((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
(LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))
return LT.first;

// Otherwise we use a legal convert followed by a min+max
// Extending vector types v8f16->v8i32, fcvtl*2 + fcvt*2
if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
MTy.getScalarSizeInBits() == 32)
return LT.first * (MTy.getVectorNumElements() > 4 ? 4 : 2);
// Extending vector types v8f16->v8i32. These current scalarize but the
// codegen could be better.
if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
MTy.getScalarSizeInBits() == 64)
return MTy.getVectorNumElements() * 3;

// If we can we use a legal convert followed by a min+max
if ((LT.second.getScalarType() == MVT::f32 ||
LT.second.getScalarType() == MVT::f64 ||
(ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
LT.second.getScalarType() == MVT::f16) &&
LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
Type *LegalTy =
Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
Expand All @@ -776,9 +798,33 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
LegalTy, {LegalTy, LegalTy});
Cost += getIntrinsicInstrCost(Attrs2, CostKind);
return LT.first * Cost;
return LT.first * Cost +
((LT.second.getScalarType() != MVT::f16 || ST->hasFullFP16()) ? 0
: 1);
}
break;
// Otherwise we need to follow the default expansion that clamps the value
// using a float min/max with a fcmp+sel for nan handling when signed.
Type *FPTy = ICA.getArgTypes()[0]->getScalarType();
RetTy = RetTy->getScalarType();
if (LT.second.isVector()) {
FPTy = VectorType::get(FPTy, LT.second.getVectorElementCount());
RetTy = VectorType::get(RetTy, LT.second.getVectorElementCount());
}
IntrinsicCostAttributes Attrs1(Intrinsic::minnum, FPTy, {FPTy, FPTy});
InstructionCost Cost = getIntrinsicInstrCost(Attrs1, CostKind);
IntrinsicCostAttributes Attrs2(Intrinsic::maxnum, FPTy, {FPTy, FPTy});
Cost += getIntrinsicInstrCost(Attrs2, CostKind);
Cost +=
getCastInstrCost(IsSigned ? Instruction::FPToSI : Instruction::FPToUI,
RetTy, FPTy, TTI::CastContextHint::None, CostKind);
if (IsSigned) {
Type *CondTy = RetTy->getWithNewBitWidth(1);
Cost += getCmpSelInstrCost(BinaryOperator::FCmp, FPTy, CondTy,
CmpInst::FCMP_UNO, CostKind);
Cost += getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy,
CmpInst::FCMP_UNO, CostKind);
}
return LT.first * Cost;
}
case Intrinsic::fshl:
case Intrinsic::fshr: {
Expand Down
Loading

0 comments on commit ca50690

Please sign in to comment.