diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 612433b54f6e44..1e607a1c714afc 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -2010,6 +2010,11 @@ class TargetLoweringBase { virtual bool hasInlineStackProbe(const MachineFunction &MF) const { return false; } + virtual bool + isStackProbeInstrDefinedFlagRegLiveIn(const MachineBasicBlock *MBB) const { + return false; + } + virtual StringRef getStackProbeSymbolName(const MachineFunction &MF) const { return ""; } diff --git a/llvm/lib/CodeGen/ShrinkWrap.cpp b/llvm/lib/CodeGen/ShrinkWrap.cpp index ab57d08e527e4d..59f836e98a751c 100644 --- a/llvm/lib/CodeGen/ShrinkWrap.cpp +++ b/llvm/lib/CodeGen/ShrinkWrap.cpp @@ -957,6 +957,15 @@ bool ShrinkWrap::runOnMachineFunction(MachineFunction &MF) { if (!ArePointsInteresting()) return Changed; + const TargetLowering *TLI = MF.getSubtarget().getTargetLowering(); + while (TLI->hasInlineStackProbe(MF) && + TLI->isStackProbeInstrDefinedFlagRegLiveIn(Save)) { + Save = FindIDom<>(*Save, Save->predecessors(), *MDT); + Restore = FindIDom<>(*Restore, Restore->successors(), *MPDT); + if (!ArePointsInteresting()) + return Changed; + } + LLVM_DEBUG(dbgs() << "Final shrink wrap candidates:\nSave: " << printMBBReference(*Save) << ' ' << "\nRestore: " << printMBBReference(*Restore) << '\n'); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a3b7e3128ac1a4..ea7d955c49efc3 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -27349,3 +27349,8 @@ bool AArch64TargetLowering::hasInlineStackProbe( return !Subtarget->isTargetWindows() && MF.getInfo()->hasStackProbing(); } + +bool AArch64TargetLowering::isStackProbeInstrDefinedFlagRegLiveIn( + const MachineBasicBlock *MBB) const { + return MBB->isLiveIn(AArch64::NZCV); +} diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 436b21fd134632..b22573d1b5fb31 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -997,6 +997,9 @@ class AArch64TargetLowering : public TargetLowering { /// True if stack clash protection is enabled for this functions. bool hasInlineStackProbe(const MachineFunction &MF) const override; + bool isStackProbeInstrDefinedFlagRegLiveIn( + const MachineBasicBlock *MBB) const override; + private: /// Keep a pointer to the AArch64Subtarget around so that we can /// make the right decision when generating code for different targets. diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 18f9871b2bd0c3..89bbd59806f830 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -57882,6 +57882,11 @@ X86TargetLowering::getStackProbeSize(const MachineFunction &MF) const { 4096); } +bool X86TargetLowering::isStackProbeInstrDefinedFlagRegLiveIn( + const MachineBasicBlock *MBB) const { + return MBB->isLiveIn(X86::EFLAGS); +} + Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const { if (ML && ML->isInnermost() && ExperimentalPrefInnermostLoopAlignment.getNumOccurrences()) diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index f93c54781846bf..575f9cf63cb5c2 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -1550,6 +1550,9 @@ namespace llvm { bool hasInlineStackProbe(const MachineFunction &MF) const override; StringRef getStackProbeSymbolName(const MachineFunction &MF) const override; + bool isStackProbeInstrDefinedFlagRegLiveIn( + const MachineBasicBlock *MBB) const override; + unsigned getStackProbeSize(const MachineFunction &MF) const; bool hasVectorBlend() const override { return true; }