From 5e8c1c420f6aabfeed6262814680eb045e53cc2e Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Fri, 12 Jul 2024 10:24:55 -0700 Subject: [PATCH] [SandboxIR] IR Tracker This is the first patch in a series of patches for the IR change tracking component of SandboxIR. The tracker collects changes in a vector of `IRChangeBase` objects and provides a `save()`/`accept()`/`revert()` API. Each type of IR changing event is captured by a dedicated subclass of `IRChangeBase`. This patch implements only one of them, that for updating a `sandboxir::Use` source value, named `UseSet`. --- llvm/docs/SandboxIR.md | 11 ++ llvm/include/llvm/SandboxIR/SandboxIR.h | 12 ++ .../include/llvm/SandboxIR/SandboxIRTracker.h | 181 ++++++++++++++++++ llvm/include/llvm/SandboxIR/Use.h | 1 + llvm/lib/SandboxIR/CMakeLists.txt | 1 + llvm/lib/SandboxIR/SandboxIR.cpp | 26 ++- llvm/lib/SandboxIR/SandboxIRTracker.cpp | 84 ++++++++ llvm/unittests/SandboxIR/CMakeLists.txt | 1 + .../SandboxIR/SandboxIRTrackerTest.cpp | 154 +++++++++++++++ 9 files changed, 470 insertions(+), 1 deletion(-) create mode 100644 llvm/include/llvm/SandboxIR/SandboxIRTracker.h create mode 100644 llvm/lib/SandboxIR/SandboxIRTracker.cpp create mode 100644 llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp diff --git a/llvm/docs/SandboxIR.md b/llvm/docs/SandboxIR.md index 8f8752f102c760..29f5e5ea9346f7 100644 --- a/llvm/docs/SandboxIR.md +++ b/llvm/docs/SandboxIR.md @@ -51,3 +51,14 @@ For example, for `sandboxir::User::setOperand(OpIdx, sandboxir::Value *Op)`: - We get the corresponding LLVM User: `llvm::User *LLVMU = cast(Val)` - Next we get the corresponding LLVM Operand: `llvm::Value *LLVMOp = Op->Val` - Finally we modify `LLVMU`'s operand: `LLVMU->setOperand(OpIdx, LLVMOp) + +## IR Change Tracking +Sandbox IR's state can be saved and restored. +This is done with the help of the tracker component that is tightly coupled to the public Sandbox IR API functions. + +To save the state and enable tracking the user needs to call `sandboxir::Context::save()`. +From this point on any change made to the Sandbox IR state will automatically create a change object and register it with the tracker, without any intervention from the user. +The changes are accumulated in a vector within the tracker. + +To rollback to the saved state the user needs to call `sandboxir::Context::revert()`. +Reverting back to the saved state is a matter of going over all the accumulated states in reverse and undoing each individual change. diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index fcb581211736ee..2e2d5668f2a2cd 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -61,6 +61,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/SandboxIR/SandboxIRTracker.h" #include "llvm/SandboxIR/Use.h" #include "llvm/Support/raw_ostream.h" #include @@ -167,6 +168,7 @@ class Value { friend class Context; // For getting `Val`. friend class User; // For getting `Val`. + friend class Use; // For getting `Val`. /// All values point to the context. Context &Ctx; @@ -630,6 +632,8 @@ class BasicBlock : public Value { class Context { protected: LLVMContext &LLVMCtx; + SandboxIRTracker IRTracker; + /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all /// SandboxIR objects. DenseMap> @@ -667,6 +671,14 @@ class Context { public: Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {} + SandboxIRTracker &getTracker() { return IRTracker; } + /// Convenience function for `getTracker().save()` + void save() { IRTracker.save(); } + /// Convenience function for `getTracker().revert()` + void revert() { IRTracker.revert(); } + /// Convenience function for `getTracker().accept()` + void accept() { IRTracker.accept(); } + sandboxir::Value *getValue(llvm::Value *V) const; const sandboxir::Value *getValue(const llvm::Value *V) const { return getValue(const_cast(V)); diff --git a/llvm/include/llvm/SandboxIR/SandboxIRTracker.h b/llvm/include/llvm/SandboxIR/SandboxIRTracker.h new file mode 100644 index 00000000000000..8a819c578a1566 --- /dev/null +++ b/llvm/include/llvm/SandboxIR/SandboxIRTracker.h @@ -0,0 +1,181 @@ +//===- SandboxIRTracker.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is the component of SandboxIR that tracks all changes made to its +// state, such that we can revert the state when needed. +// +// Tracking changes +// ---------------- +// The user needs to call `SandboxIRTracker::save()` to enable tracking changes +// made to SandboxIR. From that point on, any change made to SandboxIR, will +// automatically create a change tracking object and register it with the +// tracker. IR-change objects are subclasses of `IRChangeBase` and get +// registered with the `SandboxIRTracker::track()` function. The change objects +// are saved in the order they are registered with the tracker and are stored in +// the `SandboxIRTracker::Changes` vector. All of this is done transparently to +// the user. +// +// Reverting changes +// ----------------- +// Calling `SandboxIRTracker::revert()` will restore the state saved when +// `SandboxIRTracker::save()` was called. Internally this goes through the +// change objects in `SandboxIRTracker::Changes` in reverse order, calling their +// `IRChangeBase::revert()` function one by one. +// +// Accepting changes +// ----------------- +// The user needs to either revert or accept changes before the tracker object +// is destroyed, or else the tracker destructor will cause a crash. +// This is the job of `SandboxIRTracker::accept()`. Internally this will go +// through the change objects in `SandboxIRTracker::Changes` in order, calling +// `IRChangeBase::accept()`. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SANDBOXIR_SANDBOXIRTRACKER_H +#define LLVM_SANDBOXIR_SANDBOXIRTRACKER_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/SandboxIR/Use.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace llvm::sandboxir { + +class BasicBlock; + +/// Each IR change type has an ID. +enum class TrackID { + UseSet, +}; + +#ifndef NDEBUG +static const char *trackIDToStr(TrackID ID) { + switch (ID) { + case TrackID::UseSet: + return "UseSet"; + } + llvm_unreachable("Unimplemented ID"); +} +#endif // NDEBUG + +class SandboxIRTracker; + +/// The base class for IR Change classes. +class IRChangeBase { +protected: +#ifndef NDEBUG + unsigned Idx = 0; +#endif + const TrackID ID; + SandboxIRTracker &Parent; + +public: + IRChangeBase(TrackID ID, SandboxIRTracker &Parent); + TrackID getTrackID() const { return ID; } + /// This runs when changes get reverted. + virtual void revert() = 0; + /// This runs when changes get accepted. + virtual void accept() = 0; + virtual ~IRChangeBase() = default; +#ifndef NDEBUG + void dumpCommon(raw_ostream &OS) const { + OS << Idx << ". " << trackIDToStr(ID); + } + virtual void dump(raw_ostream &OS) const = 0; + LLVM_DUMP_METHOD virtual void dump() const = 0; +#endif +}; + +/// Change the source Value of a sandboxir::Use. +class UseSet : public IRChangeBase { + Use U; + Value *OrigV = nullptr; + +public: + UseSet(const Use &U, SandboxIRTracker &Tracker) + : IRChangeBase(TrackID::UseSet, Tracker), U(U), OrigV(U.get()) {} + // For isa<> etc. + static bool classof(const IRChangeBase *Other) { + return Other->getTrackID() == TrackID::UseSet; + } + void revert() final { U.set(OrigV); } + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { dumpCommon(OS); } + LLVM_DUMP_METHOD void dump() const final; + friend raw_ostream &operator<<(raw_ostream &OS, const UseSet &C) { + C.dump(OS); + return OS; + } +#endif +}; + +/// The tracker collects all the change objects and implements the main API for +/// saving / reverting / accepting. +class SandboxIRTracker { +public: + enum class TrackerState { + Disabled, ///> Tracking is disabled + Record, ///> Tracking changes + Revert, ///> Undoing changes + Accept, ///> Accepting changes + }; + +private: + /// The list of changes that are being tracked. + SmallVector> Changes; + /// The current state of the tracker. + TrackerState State = TrackerState::Disabled; + +public: +#ifndef NDEBUG + /// Helps catch bugs where we are creating new change objects while in the + /// middle of creating other change objects. + bool InMiddleOfCreatingChange = false; +#endif // NDEBUG + + SandboxIRTracker() = default; + ~SandboxIRTracker(); + /// Record \p Change and take ownership. This is the main function used to + /// track Sandbox IR changes. + void track(std::unique_ptr &&Change); + /// \Returns true if the tracker is recording changes. + bool tracking() const { return State == TrackerState::Record; } + /// \Returns the current state of the tracker. + TrackerState getState() const { return State; } + /// Turns on IR tracking. + void save(); + /// Stops tracking and accept changes. + void accept(); + /// Stops tracking and reverts to saved state. + void revert(); + /// \Returns the number of change entries recorded so far. + unsigned size() const { return Changes.size(); } + /// \Returns true if there are no change entries recorded so far. + bool empty() const { return Changes.empty(); } + +#ifndef NDEBUG + /// \Returns the \p Idx'th change. This is used for testing. + IRChangeBase *getChange(unsigned Idx) const { return Changes[Idx].get(); } + void dump(raw_ostream &OS) const; + LLVM_DUMP_METHOD void dump() const; + friend raw_ostream &operator<<(raw_ostream &OS, const SandboxIRTracker &C) { + C.dump(OS); + return OS; + } +#endif // NDEBUG +}; + +} // namespace llvm::sandboxir + +#endif // LLVM_SANDBOXIR_SANDBOXIRTRACKER_H diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h index 33afb54c1ff297..d77b4568d0fab0 100644 --- a/llvm/include/llvm/SandboxIR/Use.h +++ b/llvm/include/llvm/SandboxIR/Use.h @@ -44,6 +44,7 @@ class Use { public: operator Value *() const { return get(); } Value *get() const; + void set(Value *V); class User *getUser() const { return Usr; } unsigned getOperandNo() const; Context *getContext() const { return Ctx; } diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt index 225eca0cadd1ad..74b31fe869aed8 100644 --- a/llvm/lib/SandboxIR/CMakeLists.txt +++ b/llvm/lib/SandboxIR/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMSandboxIR SandboxIR.cpp + SandboxIRTracker.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index a3f350e9ca8b03..a9f564c6591b66 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -16,6 +16,8 @@ using namespace llvm::sandboxir; Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); } +void Use::set(Value *V) { LLVMUse->set(V->Val); } + unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); } #ifndef NDEBUG @@ -112,13 +114,24 @@ void Value::replaceUsesWithIf( User *DstU = cast_or_null(Ctx.getValue(LLVMUse.getUser())); if (DstU == nullptr) return false; - return ShouldReplace(Use(&LLVMUse, DstU, Ctx)); + Use UseToReplace(&LLVMUse, DstU, Ctx); + if (!ShouldReplace(UseToReplace)) + return false; + auto &Tracker = Ctx.getTracker(); + if (Tracker.tracking()) + Tracker.track(std::make_unique(UseToReplace, Tracker)); + return true; }); } void Value::replaceAllUsesWith(Value *Other) { assert(getType() == Other->getType() && "Replacing with Value of different type!"); + auto &Tracker = Ctx.getTracker(); + if (Tracker.tracking()) { + for (auto Use : uses()) + Tracker.track(std::make_unique(Use, Tracker)); + } Val->replaceAllUsesWith(Other->Val); } @@ -208,10 +221,21 @@ bool User::classof(const Value *From) { void User::setOperand(unsigned OperandIdx, Value *Operand) { assert(isa(Val) && "No operands!"); + auto &Tracker = Ctx.getTracker(); + if (Tracker.tracking()) + Tracker.track(std::make_unique(getOperandUse(OperandIdx), Tracker)); cast(Val)->setOperand(OperandIdx, Operand->Val); } bool User::replaceUsesOfWith(Value *FromV, Value *ToV) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.tracking()) { + for (auto OpIdx : seq(0, getNumOperands())) { + auto Use = getOperandUse(OpIdx); + if (Use.get() == FromV) + Tracker.track(std::make_unique(Use, Tracker)); + } + } return cast(Val)->replaceUsesOfWith(FromV->Val, ToV->Val); } diff --git a/llvm/lib/SandboxIR/SandboxIRTracker.cpp b/llvm/lib/SandboxIR/SandboxIRTracker.cpp new file mode 100644 index 00000000000000..0b62df46c020c3 --- /dev/null +++ b/llvm/lib/SandboxIR/SandboxIRTracker.cpp @@ -0,0 +1,84 @@ +//===- SandboxIRTracker.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/SandboxIR/SandboxIRTracker.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instruction.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include + +using namespace llvm::sandboxir; + +IRChangeBase::IRChangeBase(TrackID ID, SandboxIRTracker &Parent) + : ID(ID), Parent(Parent) { +#ifndef NDEBUG + Idx = Parent.size(); + + assert(!Parent.InMiddleOfCreatingChange && + "We are in the middle of creating another change!"); + if (Parent.tracking()) + Parent.InMiddleOfCreatingChange = true; +#endif // NDEBUG +} + +#ifndef NDEBUG +void UseSet::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + +SandboxIRTracker::~SandboxIRTracker() { + assert(Changes.empty() && "You must accept or revert changes!"); +} + +void SandboxIRTracker::track(std::unique_ptr &&Change) { +#ifndef NDEBUG + assert(State != TrackerState::Revert && + "No changes should be tracked during revert()!"); +#endif // NDEBUG + Changes.push_back(std::move(Change)); + +#ifndef NDEBUG + InMiddleOfCreatingChange = false; +#endif +} + +void SandboxIRTracker::save() { State = TrackerState::Record; } + +void SandboxIRTracker::revert() { + auto SavedState = State; + State = TrackerState::Revert; + for (auto &Change : reverse(Changes)) + Change->revert(); + Changes.clear(); + State = SavedState; +} + +void SandboxIRTracker::accept() { + auto SavedState = State; + State = TrackerState::Accept; + for (auto &Change : Changes) + Change->accept(); + Changes.clear(); + State = SavedState; +} + +#ifndef NDEBUG +void SandboxIRTracker::dump(raw_ostream &OS) const { + for (const auto &ChangePtr : Changes) { + ChangePtr->dump(OS); + OS << "\n"; + } +} +void SandboxIRTracker::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt index 362653bfff965d..1bb1a6efbef302 100644 --- a/llvm/unittests/SandboxIR/CMakeLists.txt +++ b/llvm/unittests/SandboxIR/CMakeLists.txt @@ -6,4 +6,5 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(SandboxIRTests SandboxIRTest.cpp + SandboxIRTrackerTest.cpp ) diff --git a/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp new file mode 100644 index 00000000000000..380d5d9ac1fd8a --- /dev/null +++ b/llvm/unittests/SandboxIR/SandboxIRTrackerTest.cpp @@ -0,0 +1,154 @@ +//===- SandboxIRTrackerTest.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct SandboxIRTrackerTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("SandboxIRTrackerTest", errs()); + } + BasicBlock *getBasicBlockByName(Function &F, StringRef Name) { + for (BasicBlock &BB : F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); + } +}; + +TEST_F(SandboxIRTrackerTest, SetOperand) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr float, ptr %ptr, i32 0 + %gep1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %gep0 + store float undef, ptr %gep0 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(&LLVMF); + auto *BB = &*F->begin(); + auto &Tracker = Ctx.getTracker(); + Tracker.save(); + auto It = BB->begin(); + auto *Gep0 = &*It++; + auto *Gep1 = &*It++; + auto *Ld = &*It++; + auto *St = &*It++; + St->setOperand(0, Ld); + EXPECT_EQ(Tracker.size(), 1u); + St->setOperand(1, Gep1); + EXPECT_EQ(Tracker.size(), 2u); + Ld->setOperand(0, Gep1); + EXPECT_EQ(Tracker.size(), 3u); + EXPECT_EQ(St->getOperand(0), Ld); + EXPECT_EQ(St->getOperand(1), Gep1); + EXPECT_EQ(Ld->getOperand(0), Gep1); + + Ctx.getTracker().revert(); + EXPECT_NE(St->getOperand(0), Ld); + EXPECT_EQ(St->getOperand(1), Gep0); + EXPECT_EQ(Ld->getOperand(0), Gep0); +} + +TEST_F(SandboxIRTrackerTest, RUWIf_RAUW_RUOW) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %ld0 = load float, ptr %ptr + %ld1 = load float, ptr %ptr + store float %ld0, ptr %ptr + store float %ld0, ptr %ptr + ret void +} +)IR"); + llvm::Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + llvm::BasicBlock *LLVMBB = &*LLVMF.begin(); + auto &Tracker = Ctx.getTracker(); + Ctx.createFunction(&LLVMF); + auto *BB = cast(Ctx.getValue(LLVMBB)); + auto It = BB->begin(); + sandboxir::Instruction *Ld0 = &*It++; + sandboxir::Instruction *Ld1 = &*It++; + sandboxir::Instruction *St0 = &*It++; + sandboxir::Instruction *St1 = &*It++; + Ctx.save(); + // Check RUWIf when the lambda returns false. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; }); + EXPECT_TRUE(Tracker.empty()); + + // Check RUWIf when the lambda returns true. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; }); + EXPECT_EQ(Tracker.size(), 2u); + EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUWIf user == St0. + Ctx.save(); + Ld0->replaceUsesWithIf( + Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; }); + EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld0); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUWIf user == St1. + Ctx.save(); + Ld0->replaceUsesWithIf( + Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; }); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RAUW. + Ctx.save(); + Ld1->replaceAllUsesWith(Ld0); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUOW. + Ctx.save(); + St0->replaceUsesOfWith(Ld0, Ld1); + EXPECT_EQ(Tracker.size(), 1u); + EXPECT_EQ(St0->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + + // Check accept(). + St0->replaceUsesOfWith(Ld0, Ld1); + EXPECT_EQ(Tracker.size(), 1u); + EXPECT_EQ(St0->getOperand(0), Ld1); + Ctx.accept(); + EXPECT_TRUE(Tracker.empty()); + EXPECT_EQ(St0->getOperand(0), Ld1); +}