From 21c853655b1471e3287b907a94e7c5d3441d7cab Mon Sep 17 00:00:00 2001 From: vporpo Date: Wed, 17 Jul 2024 21:57:52 -0700 Subject: [PATCH] [SandboxIR] IR Tracker (#99238) This is the first patch in a series of patches for the IR change tracking component of SandboxIR. The tracker collects changes in a vector of `IRChangeBase` objects and provides a `save()`/`accept()`/`revert()` API. Each type of IR changing event is captured by a dedicated subclass of `IRChangeBase`. This patch implements only one of them, that for updating a `sandboxir::Use` source value, named `UseSet`. --- llvm/docs/SandboxIR.md | 18 +++ llvm/include/llvm/SandboxIR/SandboxIR.h | 12 ++ llvm/include/llvm/SandboxIR/Tracker.h | 155 +++++++++++++++++++++++ llvm/include/llvm/SandboxIR/Use.h | 1 + llvm/lib/SandboxIR/CMakeLists.txt | 1 + llvm/lib/SandboxIR/SandboxIR.cpp | 26 +++- llvm/lib/SandboxIR/Tracker.cpp | 82 ++++++++++++ llvm/unittests/SandboxIR/CMakeLists.txt | 1 + llvm/unittests/SandboxIR/TrackerTest.cpp | 148 ++++++++++++++++++++++ 9 files changed, 443 insertions(+), 1 deletion(-) create mode 100644 llvm/include/llvm/SandboxIR/Tracker.h create mode 100644 llvm/lib/SandboxIR/Tracker.cpp create mode 100644 llvm/unittests/SandboxIR/TrackerTest.cpp diff --git a/llvm/docs/SandboxIR.md b/llvm/docs/SandboxIR.md index 8f8752f102c7601..3b792659bb59ba4 100644 --- a/llvm/docs/SandboxIR.md +++ b/llvm/docs/SandboxIR.md @@ -51,3 +51,21 @@ For example, for `sandboxir::User::setOperand(OpIdx, sandboxir::Value *Op)`: - We get the corresponding LLVM User: `llvm::User *LLVMU = cast(Val)` - Next we get the corresponding LLVM Operand: `llvm::Value *LLVMOp = Op->Val` - Finally we modify `LLVMU`'s operand: `LLVMU->setOperand(OpIdx, LLVMOp) + +## IR Change Tracking +Sandbox IR's state can be saved and restored. +This is done with the help of the tracker component that is tightly coupled to the public Sandbox IR API functions. +Please note that nested saves/restores are currently not supported. + +To save the state and enable tracking the user needs to call `sandboxir::Context::save()`. +From this point on any change made to the Sandbox IR state will automatically create a change object and register it with the tracker, without any intervention from the user. +The changes are accumulated in a vector within the tracker. + +To rollback to the saved state the user needs to call `sandboxir::Context::revert()`. +Reverting back to the saved state is a matter of going over all the accumulated changes in reverse and undoing each individual change. + +To accept the changes made to the IR the user needs to call `sandboxir::Context::accept()`. +Internally this will go through the changes and run any finalization required. + +Please note that after a call to `revert()` or `accept()` tracking will stop. +To start tracking again, the user needs to call `save()`. diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 473bd93aea7c106..c5d59ba47ca310a 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -61,6 +61,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/SandboxIR/Tracker.h" #include "llvm/SandboxIR/Use.h" #include "llvm/Support/raw_ostream.h" #include @@ -171,6 +172,7 @@ class Value { friend class Context; // For getting `Val`. friend class User; // For getting `Val`. + friend class Use; // For getting `Val`. /// All values point to the context. Context &Ctx; @@ -641,6 +643,8 @@ class BasicBlock : public Value { class Context { protected: LLVMContext &LLVMCtx; + Tracker IRTracker; + /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all /// SandboxIR objects. DenseMap> @@ -680,6 +684,14 @@ class Context { public: Context(LLVMContext &LLVMCtx) : LLVMCtx(LLVMCtx) {} + Tracker &getTracker() { return IRTracker; } + /// Convenience function for `getTracker().save()` + void save() { IRTracker.save(); } + /// Convenience function for `getTracker().revert()` + void revert() { IRTracker.revert(); } + /// Convenience function for `getTracker().accept()` + void accept() { IRTracker.accept(); } + sandboxir::Value *getValue(llvm::Value *V) const; const sandboxir::Value *getValue(const llvm::Value *V) const { return getValue(const_cast(V)); diff --git a/llvm/include/llvm/SandboxIR/Tracker.h b/llvm/include/llvm/SandboxIR/Tracker.h new file mode 100644 index 000000000000000..2d0904f5665b139 --- /dev/null +++ b/llvm/include/llvm/SandboxIR/Tracker.h @@ -0,0 +1,155 @@ +//===- Tracker.h ------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is the component of SandboxIR that tracks all changes made to its +// state, such that we can revert the state when needed. +// +// Tracking changes +// ---------------- +// The user needs to call `Tracker::save()` to enable tracking changes +// made to SandboxIR. From that point on, any change made to SandboxIR, will +// automatically create a change tracking object and register it with the +// tracker. IR-change objects are subclasses of `IRChangeBase` and get +// registered with the `Tracker::track()` function. The change objects +// are saved in the order they are registered with the tracker and are stored in +// the `Tracker::Changes` vector. All of this is done transparently to +// the user. +// +// Reverting changes +// ----------------- +// Calling `Tracker::revert()` will restore the state saved when +// `Tracker::save()` was called. Internally this goes through the +// change objects in `Tracker::Changes` in reverse order, calling their +// `IRChangeBase::revert()` function one by one. +// +// Accepting changes +// ----------------- +// The user needs to either revert or accept changes before the tracker object +// is destroyed. This is enforced in the tracker's destructor. +// This is the job of `Tracker::accept()`. Internally this will go +// through the change objects in `Tracker::Changes` in order, calling +// `IRChangeBase::accept()`. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SANDBOXIR_TRACKER_H +#define LLVM_SANDBOXIR_TRACKER_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/SandboxIR/Use.h" +#include "llvm/Support/Debug.h" +#include +#include + +namespace llvm::sandboxir { + +class BasicBlock; +class Tracker; + +/// The base class for IR Change classes. +class IRChangeBase { +protected: + Tracker &Parent; + +public: + IRChangeBase(Tracker &Parent); + /// This runs when changes get reverted. + virtual void revert() = 0; + /// This runs when changes get accepted. + virtual void accept() = 0; + virtual ~IRChangeBase() = default; +#ifndef NDEBUG + /// \Returns the index of this change by iterating over all changes in the + /// tracker. This is only used for debugging. + unsigned getIdx() const; + void dumpCommon(raw_ostream &OS) const { OS << getIdx() << ". "; } + virtual void dump(raw_ostream &OS) const = 0; + LLVM_DUMP_METHOD virtual void dump() const = 0; + friend raw_ostream &operator<<(raw_ostream &OS, const IRChangeBase &C) { + C.dump(OS); + return OS; + } +#endif +}; + +/// Tracks the change of the source Value of a sandboxir::Use. +class UseSet : public IRChangeBase { + Use U; + Value *OrigV = nullptr; + +public: + UseSet(const Use &U, Tracker &Tracker) + : IRChangeBase(Tracker), U(U), OrigV(U.get()) {} + void revert() final { U.set(OrigV); } + void accept() final {} +#ifndef NDEBUG + void dump(raw_ostream &OS) const final { + dumpCommon(OS); + OS << "UseSet"; + } + LLVM_DUMP_METHOD void dump() const final; +#endif +}; + +/// The tracker collects all the change objects and implements the main API for +/// saving / reverting / accepting. +class Tracker { +public: + enum class TrackerState { + Disabled, ///> Tracking is disabled + Record, ///> Tracking changes + }; + +private: + /// The list of changes that are being tracked. + SmallVector> Changes; +#ifndef NDEBUG + friend unsigned IRChangeBase::getIdx() const; // For accessing `Changes`. +#endif + /// The current state of the tracker. + TrackerState State = TrackerState::Disabled; + +public: +#ifndef NDEBUG + /// Helps catch bugs where we are creating new change objects while in the + /// middle of creating other change objects. + bool InMiddleOfCreatingChange = false; +#endif // NDEBUG + + Tracker() = default; + ~Tracker(); + /// Record \p Change and take ownership. This is the main function used to + /// track Sandbox IR changes. + void track(std::unique_ptr &&Change); + /// \Returns true if the tracker is recording changes. + bool isTracking() const { return State == TrackerState::Record; } + /// \Returns the current state of the tracker. + TrackerState getState() const { return State; } + /// Turns on IR tracking. + void save(); + /// Stops tracking and accept changes. + void accept(); + /// Stops tracking and reverts to saved state. + void revert(); + +#ifndef NDEBUG + void dump(raw_ostream &OS) const; + LLVM_DUMP_METHOD void dump() const; + friend raw_ostream &operator<<(raw_ostream &OS, const Tracker &Tracker) { + Tracker.dump(OS); + return OS; + } +#endif // NDEBUG +}; + +} // namespace llvm::sandboxir + +#endif // LLVM_SANDBOXIR_TRACKER_H diff --git a/llvm/include/llvm/SandboxIR/Use.h b/llvm/include/llvm/SandboxIR/Use.h index 33afb54c1ff2975..d77b4568d0fab08 100644 --- a/llvm/include/llvm/SandboxIR/Use.h +++ b/llvm/include/llvm/SandboxIR/Use.h @@ -44,6 +44,7 @@ class Use { public: operator Value *() const { return get(); } Value *get() const; + void set(Value *V); class User *getUser() const { return Usr; } unsigned getOperandNo() const; Context *getContext() const { return Ctx; } diff --git a/llvm/lib/SandboxIR/CMakeLists.txt b/llvm/lib/SandboxIR/CMakeLists.txt index 225eca0cadd1ad1..6c0666b186b8a64 100644 --- a/llvm/lib/SandboxIR/CMakeLists.txt +++ b/llvm/lib/SandboxIR/CMakeLists.txt @@ -1,5 +1,6 @@ add_llvm_component_library(LLVMSandboxIR SandboxIR.cpp + Tracker.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/SandboxIR diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 2984c6eaccd6408..944869a37989c80 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -16,6 +16,8 @@ using namespace llvm::sandboxir; Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); } +void Use::set(Value *V) { LLVMUse->set(V->Val); } + unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); } #ifndef NDEBUG @@ -115,13 +117,24 @@ void Value::replaceUsesWithIf( User *DstU = cast_or_null(Ctx.getValue(LLVMUse.getUser())); if (DstU == nullptr) return false; - return ShouldReplace(Use(&LLVMUse, DstU, Ctx)); + Use UseToReplace(&LLVMUse, DstU, Ctx); + if (!ShouldReplace(UseToReplace)) + return false; + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique(UseToReplace, Tracker)); + return true; }); } void Value::replaceAllUsesWith(Value *Other) { assert(getType() == Other->getType() && "Replacing with Value of different type!"); + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) { + for (auto Use : uses()) + Tracker.track(std::make_unique(Use, Tracker)); + } // We are delegating RAUW to LLVM IR's RAUW. Val->replaceAllUsesWith(Other->Val); } @@ -212,11 +225,22 @@ bool User::classof(const Value *From) { void User::setOperand(unsigned OperandIdx, Value *Operand) { assert(isa(Val) && "No operands!"); + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) + Tracker.track(std::make_unique(getOperandUse(OperandIdx), Tracker)); // We are delegating to llvm::User::setOperand(). cast(Val)->setOperand(OperandIdx, Operand->Val); } bool User::replaceUsesOfWith(Value *FromV, Value *ToV) { + auto &Tracker = Ctx.getTracker(); + if (Tracker.isTracking()) { + for (auto OpIdx : seq(0, getNumOperands())) { + auto Use = getOperandUse(OpIdx); + if (Use.get() == FromV) + Tracker.track(std::make_unique(Use, Tracker)); + } + } // We are delegating RUOW to LLVM IR's RUOW. return cast(Val)->replaceUsesOfWith(FromV->Val, ToV->Val); } diff --git a/llvm/lib/SandboxIR/Tracker.cpp b/llvm/lib/SandboxIR/Tracker.cpp new file mode 100644 index 000000000000000..1182f5c55d10b12 --- /dev/null +++ b/llvm/lib/SandboxIR/Tracker.cpp @@ -0,0 +1,82 @@ +//===- Tracker.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/SandboxIR/Tracker.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instruction.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include + +using namespace llvm::sandboxir; + +IRChangeBase::IRChangeBase(Tracker &Parent) : Parent(Parent) { +#ifndef NDEBUG + assert(!Parent.InMiddleOfCreatingChange && + "We are in the middle of creating another change!"); + if (Parent.isTracking()) + Parent.InMiddleOfCreatingChange = true; +#endif // NDEBUG +} + +#ifndef NDEBUG +unsigned IRChangeBase::getIdx() const { + auto It = + find_if(Parent.Changes, [this](auto &Ptr) { return Ptr.get() == this; }); + return It - Parent.Changes.begin(); +} + +void UseSet::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG + +Tracker::~Tracker() { + assert(Changes.empty() && "You must accept or revert changes!"); +} + +void Tracker::track(std::unique_ptr &&Change) { + assert(State == TrackerState::Record && "The tracker should be tracking!"); + Changes.push_back(std::move(Change)); + +#ifndef NDEBUG + InMiddleOfCreatingChange = false; +#endif +} + +void Tracker::save() { State = TrackerState::Record; } + +void Tracker::revert() { + assert(State == TrackerState::Record && "Forgot to save()!"); + State = TrackerState::Disabled; + for (auto &Change : reverse(Changes)) + Change->revert(); + Changes.clear(); +} + +void Tracker::accept() { + assert(State == TrackerState::Record && "Forgot to save()!"); + State = TrackerState::Disabled; + for (auto &Change : Changes) + Change->accept(); + Changes.clear(); +} + +#ifndef NDEBUG +void Tracker::dump(raw_ostream &OS) const { + for (const auto &ChangePtr : Changes) { + ChangePtr->dump(OS); + OS << "\n"; + } +} +void Tracker::dump() const { + dump(dbgs()); + dbgs() << "\n"; +} +#endif // NDEBUG diff --git a/llvm/unittests/SandboxIR/CMakeLists.txt b/llvm/unittests/SandboxIR/CMakeLists.txt index 362653bfff965db..3f43f6337b919bc 100644 --- a/llvm/unittests/SandboxIR/CMakeLists.txt +++ b/llvm/unittests/SandboxIR/CMakeLists.txt @@ -6,4 +6,5 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(SandboxIRTests SandboxIRTest.cpp + TrackerTest.cpp ) diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp new file mode 100644 index 000000000000000..f090dc521c32b0e --- /dev/null +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -0,0 +1,148 @@ +//===- TrackerTest.cpp ----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/SandboxIR/SandboxIR.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +struct TrackerTest : public testing::Test { + LLVMContext C; + std::unique_ptr M; + + void parseIR(LLVMContext &C, const char *IR) { + SMDiagnostic Err; + M = parseAssemblyString(IR, Err, C); + if (!M) + Err.print("TrackerTest", errs()); + } + BasicBlock *getBasicBlockByName(Function &F, StringRef Name) { + for (BasicBlock &BB : F) + if (BB.getName() == Name) + return &BB; + llvm_unreachable("Expected to find basic block!"); + } +}; + +TEST_F(TrackerTest, SetOperand) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr float, ptr %ptr, i32 0 + %gep1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %gep0 + store float undef, ptr %gep0 + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(&LLVMF); + auto *BB = &*F->begin(); + auto &Tracker = Ctx.getTracker(); + Tracker.save(); + auto It = BB->begin(); + auto *Gep0 = &*It++; + auto *Gep1 = &*It++; + auto *Ld = &*It++; + auto *St = &*It++; + St->setOperand(0, Ld); + St->setOperand(1, Gep1); + Ld->setOperand(0, Gep1); + EXPECT_EQ(St->getOperand(0), Ld); + EXPECT_EQ(St->getOperand(1), Gep1); + EXPECT_EQ(Ld->getOperand(0), Gep1); + + Ctx.getTracker().revert(); + EXPECT_NE(St->getOperand(0), Ld); + EXPECT_EQ(St->getOperand(1), Gep0); + EXPECT_EQ(Ld->getOperand(0), Gep0); +} + +TEST_F(TrackerTest, RUWIf_RAUW_RUOW) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %ld0 = load float, ptr %ptr + %ld1 = load float, ptr %ptr + store float %ld0, ptr %ptr + store float %ld0, ptr %ptr + ret void +} +)IR"); + llvm::Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + llvm::BasicBlock *LLVMBB = &*LLVMF.begin(); + Ctx.createFunction(&LLVMF); + auto *BB = cast(Ctx.getValue(LLVMBB)); + auto It = BB->begin(); + sandboxir::Instruction *Ld0 = &*It++; + sandboxir::Instruction *Ld1 = &*It++; + sandboxir::Instruction *St0 = &*It++; + sandboxir::Instruction *St1 = &*It++; + Ctx.save(); + // Check RUWIf when the lambda returns false. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return false; }); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUWIf when the lambda returns true. + Ld0->replaceUsesWithIf(Ld1, [](const sandboxir::Use &Use) { return true; }); + EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUWIf user == St0. + Ctx.save(); + Ld0->replaceUsesWithIf( + Ld1, [St0](const sandboxir::Use &Use) { return Use.getUser() == St0; }); + EXPECT_EQ(St0->getOperand(0), Ld1); + EXPECT_EQ(St1->getOperand(0), Ld0); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUWIf user == St1. + Ctx.save(); + Ld0->replaceUsesWithIf( + Ld1, [St1](const sandboxir::Use &Use) { return Use.getUser() == St1; }); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RAUW. + Ctx.save(); + Ld1->replaceAllUsesWith(Ld0); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + EXPECT_EQ(St1->getOperand(0), Ld0); + + // Check RUOW. + Ctx.save(); + St0->replaceUsesOfWith(Ld0, Ld1); + EXPECT_EQ(St0->getOperand(0), Ld1); + Ctx.revert(); + EXPECT_EQ(St0->getOperand(0), Ld0); + + // Check accept(). + Ctx.save(); + St0->replaceUsesOfWith(Ld0, Ld1); + EXPECT_EQ(St0->getOperand(0), Ld1); + Ctx.accept(); + EXPECT_EQ(St0->getOperand(0), Ld1); +}