From cf81c622dd8f772c0f23edd9ca6c4d904ddfc2bb Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 24 Aug 2023 11:43:10 +0300 Subject: [PATCH] Optimize memory usage in analysis. --- .../clad/Differentiator/ReverseModeVisitor.h | 31 +- include/clad/Differentiator/TBRAnalyzer.h | 109 +++-- lib/Differentiator/CladUtils.cpp | 8 +- lib/Differentiator/ReverseModeVisitor.cpp | 8 +- lib/Differentiator/TBRAnalyzer.cpp | 455 +++++++++++------- test/Enzyme/DifferentCladEnzymeDerivatives.C | 6 +- .../LoopsReverseModeComparisonWithClad.C | 14 +- 7 files changed, 378 insertions(+), 253 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 392241374..ef964a155 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -30,7 +30,7 @@ namespace clad { class ReverseModeVisitor : public clang::ConstStmtVisitor, public VisitorBase { - + private: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in @@ -52,6 +52,10 @@ namespace clad { Stmts m_Globals; //// A reference to the output parameter of the gradient function. clang::Expr* m_Result; + /// Based on To-Be-Recorded analysis performed before differentiation, + /// tells UsefulToStoreGlobal whether a variable with a given + /// SourceLocation has to be stored before being changed or not. + std::map m_ToBeRecorded; /// A flag indicating if the Stmt we are currently visiting is inside loop. bool isInsideLoop = false; /// Output variable of vector-valued function @@ -136,7 +140,7 @@ namespace clad { return m_Blocks.back(); else if (d == direction::reverse) return m_Reverse.back(); - else + else return m_EssentialReverse.back(); } /// Create new block. @@ -145,7 +149,7 @@ namespace clad { m_Blocks.push_back({}); else if (d == direction::reverse) m_Reverse.push_back({}); - else + else m_EssentialReverse.push_back({}); return getCurrentBlock(d); } @@ -228,11 +232,6 @@ namespace clad { forceDeclCreation, IS); } - /// Based on To-Be-Recorded analysis performed before differentiation, - /// tells UsefulToStoreGlobal whether a variable with a given - /// SourceLocation has to be stored before changed or not. - std::map m_ToBeRecorded; - /// For an expr E, decides if it is useful to store it in a global temporary /// variable and replace E's further usage by a reference to that variable /// to avoid recomputiation. @@ -429,7 +428,7 @@ namespace clad { clang::QualType xType); /// Allows to easily create and manage a counter for counting the number of - /// executed iterations of a loop. + /// executed iterations of a loop. /// /// It is required to save the number of executed iterations to use the /// same number of iterations in the reverse pass. @@ -448,11 +447,11 @@ namespace clad { /// for counter; otherwise, returns nullptr. clang::Expr* getPush() const { return m_Push; } - /// Returns `clad::pop(_t)` expression if clad tape is used for + /// Returns `clad::pop(_t)` expression if clad tape is used for /// for counter; otherwise, returns nullptr. clang::Expr* getPop() const { return m_Pop; } - /// Returns reference to the last object of the clad tape if clad tape + /// Returns reference to the last object of the clad tape if clad tape /// is used as the counter; otherwise returns reference to the counter /// variable. clang::Expr* getRef() const { return m_Ref; } @@ -494,11 +493,11 @@ namespace clad { /// This class modifies forward and reverse blocks of the loop /// body so that `break` and `continue` statements are correctly - /// handled. `break` and `continue` statements are handled by + /// handled. `break` and `continue` statements are handled by /// enclosing entire reverse block loop body in a switch statement /// and only executing the statements, with the help of case labels, - /// that were executed in the associated forward iteration. This is - /// determined by keeping track of which `break`/`continue` statement + /// that were executed in the associated forward iteration. This is + /// determined by keeping track of which `break`/`continue` statement /// was hit in which iteration and that in turn helps to determine which /// case label should be selected. /// @@ -526,7 +525,7 @@ namespace clad { /// \note `m_ControlFlowTape` is only initialized if the body contains /// `continue` or `break` statement. std::unique_ptr m_ControlFlowTape; - + /// Each `break` and `continue` statement is assigned a unique number, /// starting from 1, that is used as the case label corresponding to that `break`/`continue` /// statement. `m_CaseCounter` stores the value that was used for last @@ -565,7 +564,7 @@ namespace clad { /// control flow switch statement. clang::CaseStmt* GetNextCFCaseStmt(); - /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` + /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` /// expression, where `TapeRef` and `m_CurrentCounter` are replaced /// by their actual values respectively. clang::Stmt* CreateCFTapePushExprToCurrentCase(); diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 978d86b18..c9fcadc84 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -71,12 +71,18 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// a row, which seems uncommon. It's worth considering analysing arrays as /// whole structures instead (just one VarData for the whole array). + struct VarData; + using ObjMap = + std::unordered_map>; + using ArrMap = std::unordered_map, + APIntHash>; + struct VarData { enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE }; union VarDataValue { bool fundData; - std::unordered_map objData; - std::unordered_map arrData; + std::unique_ptr objData; + std::unique_ptr arrData; VarData* refData; VarDataValue() {} ~VarDataValue() {} @@ -85,18 +91,21 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { VarDataValue val; bool isReferenced = false; - /// For non-fundamental type variables, all the child nodes have to be - /// deleted. - ~VarData() { - if (type == OBJ_TYPE) { - for (auto pair : val.objData) { - delete pair.second; - } + VarData() = default; + VarData(VarData&& other) { *this = std::move(other); } + VarData& operator=(VarData&& other) { + type = other.type; + isReferenced = other.isReferenced; + if (type == FUND_TYPE) { + val.fundData = other.val.fundData; } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { - delete pair.second; - } + val.arrData = std::move(other.val.arrData); + } else if (type == OBJ_TYPE) { + val.objData = std::move(other.val.objData); + } else if (type == REF_TYPE) { + val.refData = other.val.refData; } + return *this; } /// Recursively sets all the leaves' bools to isReq. @@ -116,13 +125,19 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// (e.g. after an if-else statements). Look at the Control Flow section for /// more information. void merge(VarData* mergeData); + void merge(const std::unique_ptr& mergeData) { + merge(mergeData.get()); + } /// Used to recursively copy VarData when separating into different branches /// (e.g. when entering an if-else statements). Look at the Control Flow /// section for more information. refVars stores copied nodes for /// corresponding original nodes in case those are referenced (a referenced /// node is a child to multiple nodes, therefore, we need to make sure we /// don't make multiple copies of it). - VarData* copy(std::unordered_map& refVars); + std::unique_ptr copy(); + std::unique_ptr + copy(std::unordered_map& refVars); + void restoreRefs(std::unordered_map& refVars); }; /// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its @@ -138,7 +153,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Given an Expr* returns its corresponding VarData. VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false); /// Adds the field FD to objData. - void addField(std::unordered_map& objData, + void addField(std::unordered_map>* objData, const FieldDecl* FD); /// Whenever an array element with a non-constant index is set to required /// this function is used to set to required all the array elements that @@ -152,7 +168,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// particular moment. /// Note: 'this' pointer does not have a declaration so nullptr is used as /// its key instead. - using VarsData = std::unordered_map; + using VarsData = + std::unordered_map>; /// Used to find DeclRefExpr's that will be used in the backwards pass. /// In order to be marked as required, a variables has to appear in a place /// where it would have a differential influence and will appear non-linearly @@ -164,7 +181,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { std::map TBRLocs; /// Stores VarsData for every branch in control flow (e.g. if-else statements, /// loops). - std::vector reqStack; + std::vector> reqStack; /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). std::vector modeStack; @@ -179,7 +196,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// The index of the innermost branch corresponding to a loop (used to handle /// break/continue statements). - size_t innermostLoopBranch = 0; + size_t innermostLoopLayer = 0; /// Tells if the current branch should be deleted instead of merged with /// others. This happens when the branch has a break/continue statement or a /// return expression in it. @@ -204,22 +221,31 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { void setIsRequired(const clang::Expr* E, bool isReq = true); //// Control Flow - /// Creates a new branch as a copy of the last used branch. - void addBranch(); - /// Merges the last into the one right before it and deletes it. - /// If keepNewVars==false, it removes all the variables that are present - /// in the last branch but not the other. If keepNewVars==true, all the new - /// variables are kept. - /// Note: The branch we are merging into is not supposed to have its own - /// local variables (this doesn't matter to the branch being merged). - void mergeAndDelete(bool keepNewVars = false); - /// Swaps the last two branches in the stack. - void swapLastPairOfBranches(); - /// Merges the current branch to a branch with a given index in the stack. - /// Current branch is NOT deleted. - /// Note: The branch we are merging into is not supposed to have its own - /// local variables (this doesn't matter to the branch being merged). - void mergeCurBranchTo(size_t targetBranchNum); + /// Returns the current branch. + VarsData& getCurBranch() { return reqStack.back().back(); } + /// Adds a new layer. + void addLayer() { reqStack.push_back(std::vector()); } + /// Creates a new empty branch. + void addBranch() { + auto& curLayer = reqStack.back(); + size_t len = curLayer.size(); + if (len == curLayer.capacity()) { + curLayer.reserve(len + 1); + } + + curLayer.resize(len + 1); + new (&curLayer[len]) VarsData(); + } + /// Merges the last layer into the one last branch on the previous layer + /// right and deletes the last layer. + void mergeLayer(); + /// Merges the last layer but, unlike the previous method, basically replaces + /// the last branch on the previous layer with the result of merging. After + /// that, removes the last layer. + void mergeLayerOnTop(); + /// Merges the branch with index targetBranch into a sourceBranchNum. + /// No branches are deleted. + void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch); /// Removes local variables from the current branch (uses localVarsStack). /// This is necessary when merging if-else branches. void removeLocalVars(); @@ -242,23 +268,12 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { // Constructor TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) { modeStack.push_back(0); - reqStack.push_back(VarsData()); - } - - // Destructor - ~TBRAnalyzer() { - /// By the end of analysis, reqStack is supposed have just one branch - /// but it's better to iterate through it just to make sure there's no - /// memory leak. - for (auto& branch : reqStack) { - for (auto pair : branch) { - delete pair.second; - } - } + addLayer(); + addBranch(); } /// Returns the result of the whole analysis - const std::map getResult() { return TBRLocs; } + std::map getResult() { return TBRLocs; } /// Visitors diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 2c06b87be..468ab9b88 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -537,12 +537,8 @@ namespace clad { std::vector GetInnermostReturnExpr(clang::Expr* E) { struct Finder : public ConstStmtVisitor { std::vector m_return_exprs; - // Sema* m_Sema; - // ASTContext* m_Context; public: - Finder(/*Sema* S*/) /* : m_Sema(S), m_Context(S.getASTContext()) */ {} - std::vector Find(const clang::Expr* E) { Visit(E); return m_return_exprs; @@ -551,7 +547,7 @@ namespace clad { void VisitBinaryOperator(const clang::BinaryOperator* BO) { if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) { Visit(BO->getLHS()); - } else if (BO->isCommaOp()) { + } else if (BO->getOpcode() == clang::BO_Comma) { /**/ } else { assert("Unexpected binary operator!!"); @@ -576,7 +572,7 @@ namespace clad { } void VisitExpr(const clang::Expr* E) { - if (auto PE = dyn_cast(E)) { + if (const auto PE = dyn_cast(E)) { Visit(PE->getSubExpr()); } } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 918a06d22..1353479cb 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -455,7 +455,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivativeAndOverload ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, const DiffRequest& request) { - TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context); + auto* analyzer = new TBRAnalyzer(&m_Context); analyzer->Analyze(FD); m_ToBeRecorded = analyzer->getResult(); delete analyzer; @@ -571,7 +571,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DifferentiateWithClad() { - TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context); + auto* analyzer = new TBRAnalyzer(&m_Context); analyzer->Analyze(m_Function); m_ToBeRecorded = analyzer->getResult(); delete analyzer; @@ -1897,10 +1897,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // call to gradient and call to original function. At this point, each arg // is either a simple expression or a reference to a temporary variable. // Therefore cloning it has constant complexity. - std::transform(std::begin(CallArgs), - std::end(CallArgs), - std::begin(CallArgs), - [this](Expr* E) { return Clone(E); }); // Recreate the original call expression. Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 80996c4db..7df16ddf5 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -8,11 +8,11 @@ void TBRAnalyzer::VarData::setIsRequired(bool isReq) { if (type == FUND_TYPE) { val.fundData = isReq; } else if (type == OBJ_TYPE) { - for (auto pair : val.objData) { + for (auto& pair : *val.objData) { pair.second->setIsRequired(isReq); } } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { + for (auto& pair : *val.arrData) { pair.second->setIsRequired(isReq); } } else if (type == REF_TYPE && val.refData) { @@ -24,22 +24,22 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { if (this->type == FUND_TYPE) { this->val.fundData = this->val.fundData || mergeData->val.fundData; } else if (this->type == OBJ_TYPE) { - for (auto pair : this->val.objData) { - pair.second->merge(mergeData->val.objData[pair.first]); + for (auto& pair : *this->val.objData) { + pair.second->merge((*mergeData->val.objData)[pair.first]); } } else if (this->type == ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. - for (auto pair : this->val.arrData) { - auto it = mergeData->val.arrData.find(pair.first); - if (it != mergeData->val.arrData.end()) { + for (auto& pair : *this->val.arrData) { + auto it = mergeData->val.arrData->find(pair.first); + if (it != mergeData->val.arrData->end()) { pair.second->merge(it->second); } } - for (auto pair : mergeData->val.arrData) { - auto it = this->val.arrData.find(pair.first); - if (it == mergeData->val.arrData.end()) { + for (auto& pair : *mergeData->val.arrData) { + auto it = this->val.arrData->find(pair.first); + if (it == mergeData->val.arrData->end()) { std::unordered_map refVars; - this->val.arrData[pair.first] = pair.second->copy(refVars); + (*this->val.arrData)[pair.first] = pair.second->copy(refVars); } } } else if (this->type == REF_TYPE && this->val.refData) { @@ -47,53 +47,63 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { } } -TBRAnalyzer::VarData* -TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { - VarData* res; +std::unique_ptr TBRAnalyzer::VarData::copy() { + std::unordered_map refVars; + std::unique_ptr res = copy(refVars); + res->restoreRefs(refVars); + return res; +} +std::unique_ptr +TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { + std::unique_ptr res = std::unique_ptr(new VarData()); /// The child node of a reference node should be copied only once. Hence, /// we use refVars to match original referenced nodes to corresponding copies. if (isReferenced) { - auto it = refVars.find(this); - if (it != refVars.end()) { - return it->second; - } else { - res = new VarData(); - refVars[this] = res; - } - } else { - res = new VarData(); + refVars[this] = res.get(); } - res->type = this->type; - if (this->type == FUND_TYPE) { res->val.fundData = this->val.fundData; } else if (this->type == OBJ_TYPE) { - for (auto pair : this->val.objData) - res->val.objData[pair.first] = pair.second->copy(refVars); + res->val.objData = std::unique_ptr(new ObjMap()); + for (auto& pair : *this->val.objData) + (*res->val.objData)[pair.first] = pair.second->copy(refVars); } else if (this->type == ARR_TYPE) { - for (auto pair : this->val.arrData) { - res->val.arrData[pair.first] = pair.second->copy(refVars); + res->val.arrData = std::unique_ptr(new ArrMap()); + for (auto& pair : *this->val.arrData) { + (*res->val.arrData)[pair.first] = pair.second->copy(refVars); } } else if (this->type == REF_TYPE && this->val.refData) { - res->val.refData = this->val.refData->copy(refVars); - res->val.refData->isReferenced = true; + res->val.refData = this->val.refData; } - return res; } +void TBRAnalyzer::VarData::restoreRefs( + std::unordered_map& refVars) { + if (this->type == OBJ_TYPE) { + for (auto& pair : *val.objData) + pair.second->restoreRefs(refVars); + } else if (this->type == ARR_TYPE) { + for (auto& pair : *this->val.arrData) { + pair.second->restoreRefs(refVars); + } + } else if (this->type == REF_TYPE && this->val.refData) { + this->val.refData = refVars[this->val.refData]; + } +} + bool TBRAnalyzer::VarData::findReq() { if (type == FUND_TYPE) { return val.fundData; } else if (type == OBJ_TYPE) { - for (auto pair : val.objData) { + for (auto& pair : *val.objData) { if (pair.second->findReq()) return true; } } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { + for (auto& pair : *val.arrData) { if (pair.second->findReq()) return true; } @@ -113,15 +123,15 @@ void TBRAnalyzer::VarData::overlay( --i; IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i]; if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) { - val.objData[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); + (*val.objData)[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { auto idx = curIdxOrMember.val.index; if (idx == llvm::APInt(32, -1, true)) { - for (auto pair : val.arrData) { + for (auto& pair : *val.arrData) { pair.second->overlay(IdxAndMemberSequence, i); } } else { - val.arrData[idx]->overlay(IdxAndMemberSequence, i); + (*val.arrData)[idx]->overlay(IdxAndMemberSequence, i); } } } @@ -141,8 +151,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, /// FUND_TYPE might be set by default earlier. if (baseData->type == VarData::VarDataType::FUND_TYPE) { baseData->type = VarData::VarDataType::OBJ_TYPE; - baseData->val.objData = - std::unordered_map(); + baseData->val.objData = std::unique_ptr(new ObjMap()); } /// if non-const index was found and it is not supposed to be added just /// return the current VarData*. @@ -151,14 +160,14 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, auto& baseObjData = baseData->val.objData; /// Add the current field if it was not added previously - if (baseObjData.find(FD) == baseObjData.end()) { - VarData* FDData = new VarData(); - baseObjData[FD] = FDData; + if (baseObjData->find(FD) == baseObjData->end()) { + (*baseObjData)[FD] = std::unique_ptr(new VarData()); + auto FDData = (*baseObjData)[FD].get(); FDData->type = VarData::VarDataType::UNDEFINED; return FDData; } - return baseObjData[FD]; + return (*baseObjData)[FD].get(); } return nullptr; } @@ -188,8 +197,7 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, /// FUND_TYPE might be set by default earlier. if (baseData->type == VarData::VarDataType::FUND_TYPE) { baseData->type = VarData::VarDataType::ARR_TYPE; - baseData->val.arrData = - std::unordered_map(); + baseData->val.arrData = std::unique_ptr(new ArrMap()); } /// if non-const index was found and it is not supposed to be added just @@ -198,26 +206,26 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, return baseData; auto& baseArrData = baseData->val.arrData; - auto itEnd = baseArrData.end(); + auto itEnd = baseArrData->end(); /// Add the current index if it was not added previously - if (baseArrData.find(idx) == itEnd) { - VarData* idxData = new VarData(); - baseArrData[idx] = idxData; + if (baseArrData->find(idx) == itEnd) { + (*baseArrData)[idx] = std::unique_ptr(new VarData()); + auto& idxData = (*baseArrData)[idx]; /// Since -1 represents non-const indices, whenever we add a new index we /// have to copy the VarData of -1's element (if an element with undefined /// index was used this might be our current element). - auto it = baseArrData.find(llvm::APInt(32, -1, true)); + auto it = baseArrData->find(llvm::APInt(32, -1, true)); if (it != itEnd) { std::unordered_map dummy; idxData = it->second->copy(dummy); } else { idxData->type = VarData::VarDataType::UNDEFINED; } - return idxData; + return idxData.get(); } - return baseArrData[idx]; + return (*baseArrData)[idx].get(); } TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, @@ -226,9 +234,20 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, /// x would be implicitly casted with the * operator). E = E->IgnoreImpCasts(); VarData* EData; - if (auto DRE = dyn_cast(E)) { - if (auto VD = dyn_cast(DRE->getDecl())) { - EData = reqStack.back()[VD]; + if (isa(E) || isa(E)) { + const VarDecl* VD = nullptr; + /// ``this`` does not have a declaration so it is represented with nullptr. + if (auto DRE = dyn_cast(E)) + VD = dyn_cast(DRE->getDecl()); + /// The index i is shifted since otherwise the last value would be i=-1 + /// and size_t can only take positive values. + for (size_t i = reqStack.size(); i > 0; --i) { + auto& branch = reqStack[i - 1].back(); + auto it = branch.find(VD); + if (it != branch.end()) { + EData = it->second.get(); + break; + } } } if (auto ME = dyn_cast(E)) { @@ -237,10 +256,6 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, if (auto ASE = dyn_cast(E)) { EData = getArrSubVarData(ASE, addNonConstIdx); } - /// 'this' pointer is represented as a nullptr. - if (auto TE = dyn_cast(E)) { - EData = reqStack.back()[nullptr]; - } /// If the type of this VarData was not defined previously set it to /// FUND_TYPE. @@ -253,20 +268,17 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, return EData; } -void TBRAnalyzer::addField( - std::unordered_map& objData, - const FieldDecl* FD) { +void TBRAnalyzer::addField(ObjMap* objData, const FieldDecl* FD) { auto varType = FD->getType(); - VarData* data = new VarData(); - objData[FD] = data; + (*objData)[FD] = std::unique_ptr(new VarData()); + VarData* data = (*objData)[FD].get(); if (varType->isReferenceType()) { data->type = VarData::VarDataType::REF_TYPE; data->val.refData = nullptr; } else if (utils::isArrayOrPointerType(varType)) { data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = - std::unordered_map(); + data->val.arrData = std::unique_ptr(new ArrMap()); } else if (varType->isBuiltinType()) { data->type = VarData::VarDataType::FUND_TYPE; data->val.fundData = false; @@ -275,7 +287,7 @@ void TBRAnalyzer::addField( auto recordDecl = varType->getAs()->getDecl(); auto& newObjData = data->val.objData; for (auto field : recordDecl->fields()) { - addField(newObjData, field); + addField(newObjData.get(), field); } } } @@ -305,16 +317,21 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { } else return; } + /// Overlay on all the VarData's recursively. - if (auto VD = dyn_cast(innermostDRE->getDecl())) - reqStack.back()[VD]->overlay(IdxAndMemberSequence, - IdxAndMemberSequence.size()); + if (auto VD = dyn_cast(innermostDRE->getDecl())) { + getCurBranch()[VD]->overlay(IdxAndMemberSequence, + IdxAndMemberSequence.size()); + } } void TBRAnalyzer::addVar(const clang::VarDecl* VD) { - // FIXME: this marks the SourceLocation of DeclStmt which doesn't work for - // declarations with multiple VarDecls. - auto& curBranch = reqStack.back(); + /// If a declaration is passed a second time (meaning it is inside a loop), + /// treat it as an assigment operation. + /// FIXME: this marks the SourceLocation of DeclStmt which doesn't work for + /// declarations with multiple VarDecls. + size_t len = reqStack.size(); + auto& curBranch = reqStack[len - 1].back(); if (curBranch.find(VD) != curBranch.end()) { auto& VDData = curBranch[VD]; if (VDData->type == VarData::VarDataType::FUND_TYPE) { @@ -322,14 +339,23 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { (deleteCurBranch ? false : VDData->findReq()); } } + /// The index here is shifted by one since otherwise the loop would end with + /// i=-1 and size_t is positive only. + for (size_t i = len - 1; i > 0; --i) { + auto& branch = reqStack[i - 1].back(); + if (branch.find(VD) != branch.end()) { + curBranch[VD] = branch[VD]->copy(); + break; + } + } if (!localVarsStack.empty()) { localVarsStack.back().push_back(VD); } auto varType = VD->getType(); - VarData* data = new VarData(); - curBranch[VD] = data; + curBranch[VD] = std::unique_ptr(new VarData()); + VarData* data = curBranch[VD].get(); if (varType->isReferenceType()) { data->type = VarData::VarDataType::REF_TYPE; @@ -343,16 +369,15 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { data->type = VarData::VarDataType::OBJ_TYPE; auto recordDecl = pointeeType->getAs()->getDecl(); auto& objData = data->val.objData; - objData = std::unordered_map(); + objData = std::unique_ptr(new ObjMap()); for (auto field : recordDecl->fields()) { - addField(objData, field); + addField(objData.get(), field); } return; } } data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = - std::unordered_map(); + data->val.arrData = std::unique_ptr(new ArrMap()); } else if (varType->isBuiltinType()) { data->type = VarData::VarDataType::FUND_TYPE; data->val.fundData = false; @@ -360,9 +385,9 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { data->type = VarData::VarDataType::OBJ_TYPE; auto recordDecl = varType->getAs()->getDecl(); auto& objData = data->val.objData; - objData = std::unordered_map(); + objData = std::unique_ptr(new ObjMap()); for (auto field : recordDecl->fields()) { - addField(objData, field); + addField(objData.get(), field); } } } @@ -385,46 +410,80 @@ void TBRAnalyzer::markLocation(const clang::Expr* E) { TBRLocs[E->getBeginLoc()] = !deleteCurBranch; } -void TBRAnalyzer::addBranch() { - VarsData& curBranch = reqStack.back(); - VarsData newBranch; - std::unordered_map refVars; - for (auto pair : curBranch) { - newBranch[pair.first] = pair.second->copy(refVars); +void TBRAnalyzer::mergeLayer() { + size_t len = reqStack.size(); + auto& removedLayer = reqStack[len - 1]; + auto& curBranch = reqStack[len - 2].back(); + + for (auto& removedBranch : removedLayer) { + for (auto& pair : curBranch) { + auto it = removedBranch.find(pair.first); + if (it != removedBranch.end()) { + pair.second->merge(it->second); + } + } + for (auto& pair : removedBranch) { + auto it = curBranch.find(pair.first); + if (it == curBranch.end()) { + curBranch[pair.first] = std::move(pair.second); + } + } } - reqStack.push_back(newBranch); + reqStack.pop_back(); } -void TBRAnalyzer::mergeAndDelete(bool keepNewVars) { - auto removedBranch = reqStack.back(); - reqStack.pop_back(); - auto& curBranch = reqStack.back(); +void TBRAnalyzer::mergeLayerOnTop() { + size_t len = reqStack.size(); + auto& removedLayer = reqStack[len - 1]; - if (keepNewVars) { - for (auto& pair : curBranch) { - removedBranch[pair.first]->merge(pair.second); - delete pair.second; - pair.second = removedBranch[pair.first]; + if (removedLayer.empty()) { + reqStack.pop_back(); + return; + } + + auto& curBranch = reqStack[len - 2].back(); + + /// First, we merge every branch on the layer with the first one there. + auto branchIter = removedLayer.begin(); + auto branchIterEnd = removedLayer.end(); + auto& firstBranch = *branchIter; + while ((++branchIter) != branchIterEnd) { + for (auto& pair : firstBranch) { + auto elemIter = branchIter->find(pair.first); + if (elemIter != branchIter->end()) { + pair.second->merge(elemIter->second); + } } - } else { - for (auto pair : curBranch) { - pair.second->merge(removedBranch[pair.first]); - delete removedBranch[pair.first]; + for (auto& pair : *branchIter) { + auto elemIter = firstBranch.find(pair.first); + if (elemIter == firstBranch.end()) { + firstBranch[pair.first] = std::move(pair.second); + } } } -} -void TBRAnalyzer::swapLastPairOfBranches() { - size_t s = reqStack.size(); - std::swap(reqStack[s - 1], reqStack[s - 2]); -} + /// Second, we place it on top of the branch on the previous layer's last + /// branch. + for (auto& pair : firstBranch) { + curBranch[pair.first] = std::move(pair.second); + } -void TBRAnalyzer::mergeCurBranchTo(size_t targetBranchNum) { - auto& targetBranch = reqStack[targetBranchNum]; - auto& curBranch = reqStack.back(); + reqStack.pop_back(); +} +void TBRAnalyzer::mergeBranchTo(size_t sourceBranchNum, + VarsData& targetBranch) { for (auto& pair : targetBranch) { - pair.second->merge(curBranch[pair.first]); + /// Index i is shifted by one since otherwise its last value could be -1 + /// and size_t is only positive. + for (size_t i = sourceBranchNum + 1; i > 0; --i) { + auto& sourceBranch = reqStack[i - 1].back(); + auto it = sourceBranch.find(pair.first); + if (it != sourceBranch.end()) { + pair.second->merge(it->second); + break; + } + } } } @@ -447,15 +506,16 @@ void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { void TBRAnalyzer::Analyze(const FunctionDecl* FD) { /// If we are analysing a method add a VarData for 'this' pointer (it is /// represented with nullptr). + if (isa(FD)) { - VarData* data = new VarData(); - reqStack.back()[nullptr] = data; - data->type = VarData::VarDataType::OBJ_TYPE; + auto& thisData = getCurBranch()[nullptr]; + thisData = std::unique_ptr(new VarData()); + thisData->type = VarData::VarDataType::OBJ_TYPE; auto recordDecl = dyn_cast(FD->getParent()); - auto& objData = data->val.objData; - objData = std::unordered_map(); + auto& objData = thisData->val.objData; + objData = std::unique_ptr(new ObjMap()); for (auto field : recordDecl->fields()) { - addField(objData, field); + addField(objData.get(), field); } } auto paramsRef = FD->parameters(); @@ -472,7 +532,7 @@ void TBRAnalyzer::VisitCompoundStmt(const CompoundStmt* CS) { void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto VD = dyn_cast(DRE->getDecl())) { - auto& curBranch = reqStack.back(); + auto& curBranch = getCurBranch(); // FIXME: this is only necessary to ensure global variables are added. // It doesn't make any sense to first add variables when visiting DeclStmt // and then checking if they were added while visiting DeclRefExpr. @@ -518,13 +578,13 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { setMode(Mode::markingMode); Visit(init); resetMode(); - auto VDExpr = reqStack.back()[VD]; + auto& VDExpr = getCurBranch()[VD]; /// if the declared variable is ref type attach its VarData* to the /// VarData* of the RHS variable. if (VDExpr->type == VarData::VarDataType::REF_TYPE) { auto RHSExpr = getExprVarData(utils::GetInnermostReturnExpr(init)[0]); VDExpr->val.refData = RHSExpr; - RHSExpr->isReferenced = VDExpr; + RHSExpr->isReferenced = true; } } } @@ -537,11 +597,12 @@ void TBRAnalyzer::VisitConditionalOperator( Visit(CO->getCond()); resetMode(); + addLayer(); addBranch(); Visit(CO->getTrueExpr()); - swapLastPairOfBranches(); + addBranch(); Visit(CO->getFalseExpr()); - mergeAndDelete(); + mergeLayerOnTop(); } void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { @@ -678,14 +739,14 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { /// ... - - /// ... - - addBranch(); - - bool thenBranchNotDeleted = true; - bool elseBranchNotDeleted = true; auto thenBranch = If->getThen(); auto elseBranch = If->getElse(); + + localVarsStack.emplace_back(); + addLayer(); + if (thenBranch) { - localVarsStack.push_back(std::vector()); + addBranch(); Visit(cond); if (condVarDecl) addVar(condVarDecl); @@ -698,20 +759,13 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { if (deleteCurBranch) { /// This section is performed if this branch had break/continue/return /// and, therefore, shouldn't be merged. - reqStack.pop_back(); + reqStack.back().pop_back(); deleteCurBranch = false; - thenBranchNotDeleted = false; - } else { - /// We have to remove local variables from then-branch to later merge the - /// else-branch into it. - removeLocalVars(); - localVarsStack.pop_back(); } } if (elseBranch) { - if (thenBranchNotDeleted) - swapLastPairOfBranches(); + addBranch(); Visit(cond); if (condVarDecl) addVar(condVarDecl); @@ -721,23 +775,23 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { resetMode(); } Visit(elseBranch); - if (deleteCurBranch && thenBranchNotDeleted) { + if (deleteCurBranch) { /// This section is performed if this branch had break/continue/return /// and, therefore, shouldn't be merged. - reqStack.pop_back(); + reqStack.back().pop_back(); deleteCurBranch = false; - elseBranchNotDeleted = false; } } - if (thenBranchNotDeleted && elseBranchNotDeleted) - mergeAndDelete(); + mergeLayerOnTop(); + removeLocalVars(); + localVarsStack.pop_back(); } void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { auto body = WS->getBody(); auto cond = WS->getCond(); - size_t backupILB = innermostLoopBranch; + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; /// Let's assume we have a section of code structured like this @@ -766,34 +820,40 @@ void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { /// ... - Visit(cond); + addLayer(); + addBranch(); + addBranch(); + + addLayer(); addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 1; firstLoopPass = true; if (body) Visit(body); if (deleteCurBranch) { - reqStack.pop_back(); + reqStack.back().pop_back(); deleteCurBranch = backupDCB; } else { Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); /// Second pass - --innermostLoopBranch; firstLoopPass = false; if (body) Visit(body); if (deleteCurBranch) - reqStack.pop_back(); + reqStack.back().pop_back(); else { Visit(cond); - mergeAndDelete(); } + mergeLayer(); - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } @@ -804,7 +864,7 @@ void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { auto init = FS->getInit(); auto cond = FS->getCond(); auto incr = FS->getInc(); - size_t backupILB = innermostLoopBranch; + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; /// The logic here is virtually the same as with while-loop. Take a look at @@ -816,42 +876,47 @@ void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { } if (cond) Visit(cond); + addLayer(); + addBranch(); addBranch(); if (condVar) addVar(condVar); + addLayer(); + addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 1; firstLoopPass = true; if (body) Visit(body); if (deleteCurBranch) { - reqStack.pop_back(); + reqStack.back().pop_back(); deleteCurBranch = backupDCB; } else { if (incr) Visit(incr); if (cond) Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); /// Second pass - --innermostLoopBranch; firstLoopPass = false; if (body) Visit(body); if (incr) Visit(incr); if (deleteCurBranch) - reqStack.pop_back(); + reqStack.back().pop_back(); else { if (cond) Visit(cond); - mergeAndDelete(); } + mergeLayer(); - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } @@ -859,7 +924,7 @@ void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { auto body = DS->getBody(); auto cond = DS->getCond(); - size_t backupILB = innermostLoopBranch; + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; @@ -870,10 +935,14 @@ void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { /// having two loop branches is necessary for handling continue statements /// so we can't just remove one of them. + addLayer(); + addBranch(); + addBranch(); + addLayer(); addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 2; firstLoopPass = true; if (body) Visit(body); @@ -882,21 +951,58 @@ void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { deleteCurBranch = backupDCB; } else { Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); + mergeLayer(); } /// Second pass - --innermostLoopBranch; + --innermostLoopLayer; firstLoopPass = false; if (body) Visit(body); Visit(cond); if (deleteCurBranch) { reqStack.pop_back(); - mergeAndDelete(); + mergeLayer(); } - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; + firstLoopPass = backupFLP; + deleteCurBranch = backupDCB; + + addLayer(); + addBranch(); + addBranch(); + + addLayer(); + addBranch(); + addBranch(); + /// First pass + innermostLoopLayer = reqStack.size() - 1; + firstLoopPass = true; + if (body) + Visit(body); + if (deleteCurBranch) { + reqStack.back().pop_back(); + deleteCurBranch = backupDCB; + } else { + Visit(cond); + } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); + + /// Second pass + firstLoopPass = false; + if (body) + Visit(body); + if (deleteCurBranch) + reqStack.back().pop_back(); + else { + Visit(cond); + } + mergeLayer(); + + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } @@ -915,11 +1021,18 @@ void TBRAnalyzer::VisitContinueStmt(const clang::ContinueStmt* CS) { /// followed by another iteration. We have to either add an additional branch /// or find a better solution. (However, this bug will matter only in really /// rare cases) - mergeCurBranchTo(innermostLoopBranch); + + auto& targetLayer1 = reqStack[innermostLoopLayer]; + auto& targetBranch1 = targetLayer1[targetLayer1.size() - 2]; + size_t sourceBranchNum = reqStack.size() - 1; + mergeBranchTo(sourceBranchNum, targetBranch1); /// After the continue statement, this branch cannot be followed by any other /// code so we can delete it. - if (firstLoopPass) - mergeCurBranchTo(innermostLoopBranch - 1); + if (firstLoopPass) { + auto& targetLayer2 = reqStack[innermostLoopLayer - 1]; + auto& targetBranch2 = targetLayer2[targetLayer2.size() - 2]; + mergeBranchTo(sourceBranchNum, targetBranch2); + } deleteCurBranch = true; } @@ -928,8 +1041,11 @@ void TBRAnalyzer::VisitBreakStmt(const clang::BreakStmt* BS) { /// ... - - /// And so this break could be the end of this loop. So we have to merge /// the current branch into the first branch on the diagram. - if (!firstLoopPass) - mergeCurBranchTo(innermostLoopBranch); + if (!firstLoopPass) { + auto& targetLayer = reqStack[innermostLoopLayer]; + auto& targetBranch = targetLayer[targetLayer.size() - 2]; + mergeBranchTo(reqStack.size() - 1, targetBranch); + } /// After the break statement, this branch cannot be followed by any other /// code so we can delete it. deleteCurBranch = true; @@ -1001,6 +1117,9 @@ void TBRAnalyzer::VisitMemberExpr(const clang::MemberExpr* ME) { void TBRAnalyzer::VisitArraySubscriptExpr( const clang::ArraySubscriptExpr* ASE) { + setMode(0); + Visit(ASE->getBase()); + resetMode(); setIsRequired(dyn_cast(ASE)); setMode(Mode::markingMode | Mode::nonLinearMode); Visit(ASE->getIdx()); @@ -1016,7 +1135,7 @@ void TBRAnalyzer::VisitInitListExpr(const clang::InitListExpr* ILE) { } void TBRAnalyzer::removeLocalVars() { - auto& curBranch = reqStack.back(); + auto& curBranch = getCurBranch(); for (auto VD : localVarsStack.back()) curBranch.erase(VD); } diff --git a/test/Enzyme/DifferentCladEnzymeDerivatives.C b/test/Enzyme/DifferentCladEnzymeDerivatives.C index c3cd617dd..f45e8e1d3 100644 --- a/test/Enzyme/DifferentCladEnzymeDerivatives.C +++ b/test/Enzyme/DifferentCladEnzymeDerivatives.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out | FileCheck %s +// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out // RUN: ./DifferentCladEnzymeDerivatives.out // CHECK-NOT: {{.*error|warning|note:.*}} // REQUIRES: Enzyme @@ -32,6 +32,6 @@ double foo(double x, double y){ // CHECK-NEXT: } int main(){ - auto grad = clad::gradient(foo); + auto grad = clad::gradient(foo); auto gradEnzyme = clad::gradient(foo); -} \ No newline at end of file +} diff --git a/test/Enzyme/LoopsReverseModeComparisonWithClad.C b/test/Enzyme/LoopsReverseModeComparisonWithClad.C index 73a3e94ec..2a8e699b7 100644 --- a/test/Enzyme/LoopsReverseModeComparisonWithClad.C +++ b/test/Enzyme/LoopsReverseModeComparisonWithClad.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oEnzymeLoops.out 2>&1 | FileCheck %s +// RUN: %cladclang %s -I%S/../../include -oEnzymeLoops.out // RUN: ./EnzymeLoops.out | FileCheck -check-prefix=CHECK-EXEC %s // REQUIRES: Enzyme // CHECK-NOT: {{.*error|warning|note:.*}} @@ -92,7 +92,7 @@ double f6 (double i, double j) { double fn7(double i, double j) { double a = 0; int counter = 3; - while (counter--) + while (counter--) a += i*i + j; return a; } @@ -229,7 +229,7 @@ double fn13(double i, double j) { double fn14(double i, double j) { int choice = 5; double res = 0; - for (int counter=0; counter