Skip to content

Commit

Permalink
[Attributor] Keep track of reached returns in AAPointerInfo (llvm#107479
Browse files Browse the repository at this point in the history
)

Instead of visiting call sites in Attribute::checkForAllUses, we now
keep track of returns in AAPointerInfo and use the call site return
information as required. This way, the user of
AAPointerInfo(CallSite)Argument can determine if the call return should
be visited. We do not collect them as "may accesses" in the
AAPointerInfo(CallSite)Argument itself in case a return user is found.
  • Loading branch information
jdoerfert authored and VitaNuo committed Sep 12, 2024
1 parent e36a366 commit 25ece8b
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 42 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/Transforms/IPO/Attributor.h
Original file line number Diff line number Diff line change
Expand Up @@ -6119,6 +6119,7 @@ struct AAPointerInfo : public AbstractAttribute {
virtual const_bin_iterator begin() const = 0;
virtual const_bin_iterator end() const = 0;
virtual int64_t numOffsetBins() const = 0;
virtual bool reachesReturn() const = 0;

/// Call \p CB on all accesses that might interfere with \p Range and return
/// true if all such accesses were known and the callback returned true for
Expand Down
16 changes: 0 additions & 16 deletions llvm/lib/Transforms/IPO/Attributor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1852,22 +1852,6 @@ bool Attributor::checkForAllUses(

User &Usr = *U->getUser();
AddUsers(Usr, /* OldUse */ nullptr);

auto *RI = dyn_cast<ReturnInst>(&Usr);
if (!RI)
continue;

Function &F = *RI->getFunction();
auto CallSitePred = [&](AbstractCallSite ACS) {
return AddUsers(*ACS.getInstruction(), U);
};
if (!checkForAllCallSites(CallSitePred, F, /* RequireAllCallSites */ true,
&QueryingAA, UsedAssumedInformation)) {
LLVM_DEBUG(dbgs() << "[Attributor] Could not follow return instruction "
"to all call sites: "
<< *RI << "\n");
return false;
}
}

return true;
Expand Down
59 changes: 49 additions & 10 deletions llvm/lib/Transforms/IPO/AttributorAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ struct AA::PointerInfo::State : public AbstractState {
AccessList = R.AccessList;
OffsetBins = R.OffsetBins;
RemoteIMap = R.RemoteIMap;
ReachesReturn = R.ReachesReturn;
return *this;
}

Expand All @@ -837,6 +838,7 @@ struct AA::PointerInfo::State : public AbstractState {
std::swap(AccessList, R.AccessList);
std::swap(OffsetBins, R.OffsetBins);
std::swap(RemoteIMap, R.RemoteIMap);
std::swap(ReachesReturn, R.ReachesReturn);
return *this;
}

Expand Down Expand Up @@ -878,11 +880,16 @@ struct AA::PointerInfo::State : public AbstractState {
AAPointerInfo::OffsetBinsTy OffsetBins;
DenseMap<const Instruction *, SmallVector<unsigned>> RemoteIMap;

/// Flag to determine if the underlying pointer is reaching a return statement
/// in the associated function or not. Returns in other functions cause
/// invalidation.
bool ReachesReturn = false;

/// See AAPointerInfo::forallInterferingAccesses.
bool forallInterferingAccesses(
AA::RangeTy Range,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const {
if (!isValidState())
if (!isValidState() || ReachesReturn)
return false;

for (const auto &It : OffsetBins) {
Expand All @@ -904,7 +911,7 @@ struct AA::PointerInfo::State : public AbstractState {
Instruction &I,
function_ref<bool(const AAPointerInfo::Access &, bool)> CB,
AA::RangeTy &Range) const {
if (!isValidState())
if (!isValidState() || ReachesReturn)
return false;

auto LocalList = RemoteIMap.find(&I);
Expand Down Expand Up @@ -1071,7 +1078,8 @@ struct AAPointerInfoImpl
return std::string("PointerInfo ") +
(isValidState() ? (std::string("#") +
std::to_string(OffsetBins.size()) + " bins")
: "<invalid>");
: "<invalid>") +
(ReachesReturn ? " (returned)" : "");
}

/// See AbstractAttribute::manifest(...).
Expand All @@ -1084,6 +1092,7 @@ struct AAPointerInfoImpl
virtual int64_t numOffsetBins() const override {
return State::numOffsetBins();
}
virtual bool reachesReturn() const override { return ReachesReturn; }

bool forallInterferingAccesses(
AA::RangeTy Range,
Expand Down Expand Up @@ -1373,6 +1382,7 @@ struct AAPointerInfoImpl

const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA);
bool IsByval = OtherAAImpl.getAssociatedArgument()->hasByValAttr();
ReachesReturn = OtherAAImpl.ReachesReturn;

// Combine the accesses bin by bin.
ChangeStatus Changed = ChangeStatus::UNCHANGED;
Expand Down Expand Up @@ -1666,8 +1676,13 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
}
if (isa<PtrToIntInst>(Usr))
return false;
if (isa<CastInst>(Usr) || isa<SelectInst>(Usr) || isa<ReturnInst>(Usr))
if (isa<CastInst>(Usr) || isa<SelectInst>(Usr))
return HandlePassthroughUser(Usr, CurPtr, Follow);
// Returns are allowed if they are in the associated functions. Users can
// then check the call site return. Returns from other functions can't be
// tracked and are cause for invalidation.
if (auto *RI = dyn_cast<ReturnInst>(Usr))
return ReachesReturn = RI->getFunction() == getAssociatedFunction();

// For PHIs we need to take care of the recurrence explicitly as the value
// might change while we iterate through a loop. For now, we give up if
Expand Down Expand Up @@ -1898,15 +1913,37 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) {
DepClassTy::REQUIRED);
if (!CSArgPI)
return false;
bool IsMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
bool IsArgMustAcc = (getUnderlyingObject(CurPtr) == &AssociatedValue);
Changed = translateAndAddState(A, *CSArgPI, OffsetInfoMap[CurPtr], *CB,
IsMustAcc) |
IsArgMustAcc) |
Changed;
if (!CSArgPI->reachesReturn())
return isValidState();

Function *Callee = CB->getCalledFunction();
if (!Callee || Callee->arg_size() <= ArgNo)
return false;
bool UsedAssumedInformation = false;
auto ReturnedValue = A.getAssumedSimplified(
IRPosition::returned(*Callee), *this, UsedAssumedInformation,
AA::ValueScope::Intraprocedural);
auto *ReturnedArg =
dyn_cast_or_null<Argument>(ReturnedValue.value_or(nullptr));
auto *Arg = Callee->getArg(ArgNo);
if (ReturnedArg && Arg != ReturnedArg)
return true;
bool IsRetMustAcc = IsArgMustAcc && (ReturnedArg == Arg);
const auto *CSRetPI = A.getAAFor<AAPointerInfo>(
*this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
if (!CSRetPI)
return false;
Changed = translateAndAddState(A, *CSRetPI, OffsetInfoMap[CurPtr], *CB,
IsRetMustAcc) |
Changed;
return isValidState();
}
LLVM_DEBUG(dbgs() << "[AAPointerInfo] Call user not handled " << *CB
<< "\n");
// TODO: Allow some call uses
return false;
}

Expand Down Expand Up @@ -2342,8 +2379,10 @@ struct AANoFreeFloating : AANoFreeImpl {
Follow = true;
return true;
}
if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI) ||
isa<ReturnInst>(UserI))
if (isa<StoreInst>(UserI) || isa<LoadInst>(UserI))
return true;

if (isa<ReturnInst>(UserI) && getIRPosition().isArgumentPosition())
return true;

// Unknown user.
Expand Down Expand Up @@ -12740,7 +12779,7 @@ struct AAAllocationInfoImpl : public AAAllocationInfo {
if (!PI)
return indicatePessimisticFixpoint();

if (!PI->getState().isValidState())
if (!PI->getState().isValidState() || PI->reachesReturn())
return indicatePessimisticFixpoint();

const DataLayout &DL = A.getDataLayout();
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/Attributor/IPConstantProp/pthreads.ll
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
define dso_local i32 @main() {
; TUNIT-LABEL: define {{[^@]+}}@main() {
; TUNIT-NEXT: entry:
; TUNIT-NEXT: [[ALLOC11:%.*]] = alloca i8, i32 0, align 8
; TUNIT-NEXT: [[ALLOC22:%.*]] = alloca i8, i32 0, align 8
; TUNIT-NEXT: [[ALLOC1:%.*]] = alloca i8, align 8
; TUNIT-NEXT: [[ALLOC2:%.*]] = alloca i8, align 8
; TUNIT-NEXT: [[THREAD:%.*]] = alloca i64, align 8
; TUNIT-NEXT: [[CALL:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @foo, ptr nofree readnone align 4294967296 undef)
; TUNIT-NEXT: [[CALL1:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @bar, ptr noalias nocapture nofree nonnull readnone align 8 dereferenceable(8) undef)
; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC11]])
; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC22]])
; TUNIT-NEXT: [[CALL2:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @baz, ptr noalias nocapture nofree noundef nonnull readnone align 8 dereferenceable(1) [[ALLOC1]])
; TUNIT-NEXT: [[CALL3:%.*]] = call i32 @pthread_create(ptr noundef nonnull align 8 dereferenceable(8) [[THREAD]], ptr noundef align 4294967296 null, ptr noundef nonnull @buz, ptr noalias nofree noundef nonnull readnone align 8 dereferenceable(1) "no-capture-maybe-returned" [[ALLOC2]])
; TUNIT-NEXT: ret i32 0
;
; CGSCC-LABEL: define {{[^@]+}}@main() {
Expand Down
15 changes: 3 additions & 12 deletions llvm/test/Transforms/Attributor/value-simplify-pointer-info.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3185,10 +3185,7 @@ define i32 @may_access_after_return(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
; TUNIT-NEXT: ret i32 [[ADD]]
; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return
Expand Down Expand Up @@ -3304,10 +3301,7 @@ define i32 @may_access_after_return_no_choice1(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]]) #[[ATTR18]]
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
; TUNIT-NEXT: ret i32 [[ADD]]
; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice1
Expand Down Expand Up @@ -3342,10 +3336,7 @@ define i32 @may_access_after_return_no_choice2(i32 noundef %N, i32 noundef %M) {
; TUNIT-NEXT: [[A:%.*]] = alloca i32, align 4
; TUNIT-NEXT: [[B:%.*]] = alloca i32, align 4
; TUNIT-NEXT: call void @write_both(ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[B]], ptr nocapture nofree noundef nonnull writeonly align 4 dereferenceable(4) [[A]]) #[[ATTR18]]
; TUNIT-NEXT: [[TMP0:%.*]] = load i32, ptr [[A]], align 4
; TUNIT-NEXT: [[TMP1:%.*]] = load i32, ptr [[B]], align 4
; TUNIT-NEXT: [[ADD:%.*]] = add nsw i32 [[TMP0]], [[TMP1]]
; TUNIT-NEXT: ret i32 [[ADD]]
; TUNIT-NEXT: ret i32 8
;
; CGSCC: Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none)
; CGSCC-LABEL: define {{[^@]+}}@may_access_after_return_no_choice2
Expand Down

0 comments on commit 25ece8b

Please sign in to comment.