From e9257f97e4f653b112bb377e885226c8647835fc Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 24 Aug 2023 11:43:10 +0300 Subject: [PATCH] Optimize memory usage in analysis. --- include/clad/Differentiator/CladUtils.h | 12 +- .../clad/Differentiator/ReverseModeVisitor.h | 31 +- include/clad/Differentiator/TBRAnalyzer.h | 112 ++-- lib/Differentiator/CladUtils.cpp | 14 +- lib/Differentiator/ReverseModeVisitor.cpp | 22 +- lib/Differentiator/TBRAnalyzer.cpp | 591 +++++++++++------- test/Enzyme/DifferentCladEnzymeDerivatives.C | 6 +- .../LoopsReverseModeComparisonWithClad.C | 14 +- 8 files changed, 469 insertions(+), 333 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 96d65ae5e..81b9d6d3d 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -28,7 +28,7 @@ namespace clad { std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD); /// Creates and returns a compound statement having statements as follows: - /// {`S`, all the statement of `initial` in sequence} + /// {`S`, all the statement of `initial` in sequence} clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C, clang::Stmt* initial, clang::Stmt* S); @@ -38,7 +38,7 @@ namespace clad { clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C, clang::Stmt* initial, clang::Stmt* S); - + /// Shorthand to issues a warning or error. template void EmitDiag(clang::Sema& semaRef, @@ -126,8 +126,8 @@ namespace clad { /// /// \param S /// \param namespc - /// \param shouldExist If true, then asserts that the specified namespace - /// is found. + /// \param shouldExist If true, then asserts that the specified namespace + /// is found. /// \param DC clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc, bool shouldExist, @@ -234,7 +234,7 @@ namespace clad { bool IsCladValueAndPushforwardType(clang::QualType T); - /// Returns a valid `SourceRange` to be used in places where clang + /// Returns a valid `SourceRange` to be used in places where clang /// requires a valid `SourceRange`. clang::SourceRange GetValidSRange(clang::Sema& semaRef); @@ -314,7 +314,7 @@ namespace clad { bool hasNonDifferentiableAttribute(const clang::Expr* E); /// FIXME: add documentation - std::vector GetInnermostReturnExpr(clang::Expr* E); + std::vector GetInnermostReturnExpr(const clang::Expr* E); } // namespace utils } diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 392241374..ef964a155 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -30,7 +30,7 @@ namespace clad { class ReverseModeVisitor : public clang::ConstStmtVisitor, public VisitorBase { - + private: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in @@ -52,6 +52,10 @@ namespace clad { Stmts m_Globals; //// A reference to the output parameter of the gradient function. clang::Expr* m_Result; + /// Based on To-Be-Recorded analysis performed before differentiation, + /// tells UsefulToStoreGlobal whether a variable with a given + /// SourceLocation has to be stored before being changed or not. + std::map m_ToBeRecorded; /// A flag indicating if the Stmt we are currently visiting is inside loop. bool isInsideLoop = false; /// Output variable of vector-valued function @@ -136,7 +140,7 @@ namespace clad { return m_Blocks.back(); else if (d == direction::reverse) return m_Reverse.back(); - else + else return m_EssentialReverse.back(); } /// Create new block. @@ -145,7 +149,7 @@ namespace clad { m_Blocks.push_back({}); else if (d == direction::reverse) m_Reverse.push_back({}); - else + else m_EssentialReverse.push_back({}); return getCurrentBlock(d); } @@ -228,11 +232,6 @@ namespace clad { forceDeclCreation, IS); } - /// Based on To-Be-Recorded analysis performed before differentiation, - /// tells UsefulToStoreGlobal whether a variable with a given - /// SourceLocation has to be stored before changed or not. - std::map m_ToBeRecorded; - /// For an expr E, decides if it is useful to store it in a global temporary /// variable and replace E's further usage by a reference to that variable /// to avoid recomputiation. @@ -429,7 +428,7 @@ namespace clad { clang::QualType xType); /// Allows to easily create and manage a counter for counting the number of - /// executed iterations of a loop. + /// executed iterations of a loop. /// /// It is required to save the number of executed iterations to use the /// same number of iterations in the reverse pass. @@ -448,11 +447,11 @@ namespace clad { /// for counter; otherwise, returns nullptr. clang::Expr* getPush() const { return m_Push; } - /// Returns `clad::pop(_t)` expression if clad tape is used for + /// Returns `clad::pop(_t)` expression if clad tape is used for /// for counter; otherwise, returns nullptr. clang::Expr* getPop() const { return m_Pop; } - /// Returns reference to the last object of the clad tape if clad tape + /// Returns reference to the last object of the clad tape if clad tape /// is used as the counter; otherwise returns reference to the counter /// variable. clang::Expr* getRef() const { return m_Ref; } @@ -494,11 +493,11 @@ namespace clad { /// This class modifies forward and reverse blocks of the loop /// body so that `break` and `continue` statements are correctly - /// handled. `break` and `continue` statements are handled by + /// handled. `break` and `continue` statements are handled by /// enclosing entire reverse block loop body in a switch statement /// and only executing the statements, with the help of case labels, - /// that were executed in the associated forward iteration. This is - /// determined by keeping track of which `break`/`continue` statement + /// that were executed in the associated forward iteration. This is + /// determined by keeping track of which `break`/`continue` statement /// was hit in which iteration and that in turn helps to determine which /// case label should be selected. /// @@ -526,7 +525,7 @@ namespace clad { /// \note `m_ControlFlowTape` is only initialized if the body contains /// `continue` or `break` statement. std::unique_ptr m_ControlFlowTape; - + /// Each `break` and `continue` statement is assigned a unique number, /// starting from 1, that is used as the case label corresponding to that `break`/`continue` /// statement. `m_CaseCounter` stores the value that was used for last @@ -565,7 +564,7 @@ namespace clad { /// control flow switch statement. clang::CaseStmt* GetNextCFCaseStmt(); - /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` + /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` /// expression, where `TapeRef` and `m_CurrentCounter` are replaced /// by their actual values respectively. clang::Stmt* CreateCFTapePushExprToCurrentCase(); diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 978d86b18..c96634142 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -30,22 +30,30 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { union IdxOrMemberValue { const clang::FieldDecl* field; llvm::APInt index; - IdxOrMemberValue() {} + IdxOrMemberValue() : field(nullptr) {} ~IdxOrMemberValue() {} + IdxOrMemberValue(const IdxOrMemberValue&) = delete; + IdxOrMemberValue& operator=(const IdxOrMemberValue&) = delete; }; IdxOrMemberType type; IdxOrMemberValue val; IdxOrMember(const clang::FieldDecl* field) : type(IdxOrMemberType::FIELD) { val.field = field; } - IdxOrMember(llvm::APInt index) : type(IdxOrMemberType::INDEX) { + IdxOrMember(llvm::APInt&& index) : type(IdxOrMemberType::INDEX) { new (&val.index) llvm::APInt(index); } IdxOrMember(const IdxOrMember& other) : type(other.type) { + new (&val.index) llvm::APInt(); + *this = other; + } + IdxOrMember& operator=(const IdxOrMember& other) { + type = other.type; if (type == IdxOrMemberType::FIELD) val.field = other.val.field; else - new (&val.index) llvm::APInt(other.val.index); + val.index = other.val.index; + return *this; } }; @@ -71,39 +79,45 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// a row, which seems uncommon. It's worth considering analysing arrays as /// whole structures instead (just one VarData for the whole array). + struct VarData; + using ObjMap = std::unordered_map; + using ArrMap = std::unordered_map; + struct VarData { enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE }; union VarDataValue { bool fundData; - std::unordered_map objData; - std::unordered_map arrData; + /// objData, arrData are stored as pointers for VarDataValue to take + /// less space. + ObjMap* objData; + ArrMap* arrData; VarData* refData; - VarDataValue() {} - ~VarDataValue() {} + VarDataValue() : fundData(false) {} }; VarDataType type; VarDataValue val; bool isReferenced = false; - /// For non-fundamental type variables, all the child nodes have to be - /// deleted. + VarData() = default; + VarData(const VarData&) = delete; + VarData& operator=(const VarData&) = delete; + ~VarData() { if (type == OBJ_TYPE) { - for (auto pair : val.objData) { + for (auto& pair : *val.objData) { delete pair.second; } } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { + for (auto& pair : *val.arrData) { delete pair.second; } } } - /// Recursively sets all the leaves' bools to isReq. void setIsRequired(bool isReq = true); /// Returns true if there is at least one required to store node among /// child nodes. - bool findReq(); + bool findReq() const; /// Whenever an array element with a non-constant index is set to required /// this function is used to set to required all the array elements that /// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required @@ -122,7 +136,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// corresponding original nodes in case those are referenced (a referenced /// node is a child to multiple nodes, therefore, we need to make sure we /// don't make multiple copies of it). + VarData* copy(); VarData* copy(std::unordered_map& refVars); + void restoreRefs(std::unordered_map& refVars); }; /// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its @@ -138,7 +154,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Given an Expr* returns its corresponding VarData. VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false); /// Adds the field FD to objData. - void addField(std::unordered_map& objData, + void addField(std::unordered_map* objData, const FieldDecl* FD); /// Whenever an array element with a non-constant index is set to required /// this function is used to set to required all the array elements that @@ -164,7 +180,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { std::map TBRLocs; /// Stores VarsData for every branch in control flow (e.g. if-else statements, /// loops). - std::vector reqStack; + std::vector> reqStack; /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). std::vector modeStack; @@ -179,7 +195,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// The index of the innermost branch corresponding to a loop (used to handle /// break/continue statements). - size_t innermostLoopBranch = 0; + size_t innermostLoopLayer = 0; /// Tells if the current branch should be deleted instead of merged with /// others. This happens when the branch has a break/continue statement or a /// return expression in it. @@ -204,22 +220,28 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { void setIsRequired(const clang::Expr* E, bool isReq = true); //// Control Flow - /// Creates a new branch as a copy of the last used branch. - void addBranch(); - /// Merges the last into the one right before it and deletes it. - /// If keepNewVars==false, it removes all the variables that are present - /// in the last branch but not the other. If keepNewVars==true, all the new - /// variables are kept. - /// Note: The branch we are merging into is not supposed to have its own - /// local variables (this doesn't matter to the branch being merged). - void mergeAndDelete(bool keepNewVars = false); - /// Swaps the last two branches in the stack. - void swapLastPairOfBranches(); - /// Merges the current branch to a branch with a given index in the stack. - /// Current branch is NOT deleted. - /// Note: The branch we are merging into is not supposed to have its own - /// local variables (this doesn't matter to the branch being merged). - void mergeCurBranchTo(size_t targetBranchNum); + /// Returns the current branch. + VarsData& getCurBranch() { return reqStack.back().back(); } + /// Adds a new layer. + void addLayer() { reqStack.emplace_back(); } + /// Creates a new empty branch. + void addBranch() { reqStack.back().emplace_back(); } + /// Deletes the last branch. + void deleteBranch() { + for (auto& pair : getCurBranch()) + delete pair.second; + reqStack.back().pop_back(); + } + /// Merges the last layer into the one last branch on the previous layer + /// right and deletes the last layer. + void mergeLayer(); + /// Merges the last layer but, unlike the previous method, basically replaces + /// the last branch on the previous layer with the result of merging. After + /// that, removes the last layer. + void mergeLayerOnTop(); + /// Merges the branch with index targetBranch into a sourceBranchNum. + /// No branches are deleted. + void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch); /// Removes local variables from the current branch (uses localVarsStack). /// This is necessary when merging if-else branches. void removeLocalVars(); @@ -239,29 +261,29 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { void resetMode() { modeStack.pop_back(); } public: - // Constructor + /// Constructor TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) { modeStack.push_back(0); - reqStack.push_back(VarsData()); + addLayer(); + addBranch(); } - // Destructor + /// Destructor ~TBRAnalyzer() { - /// By the end of analysis, reqStack is supposed have just one branch - /// but it's better to iterate through it just to make sure there's no - /// memory leak. - for (auto& branch : reqStack) { - for (auto pair : branch) { - delete pair.second; - } - } + for (auto& layer : reqStack) + for (auto& branch : layer) + for (auto& pair : branch) + delete pair.second; } + /// Delete copy operator and constructor. + TBRAnalyzer(const TBRAnalyzer&) = delete; + TBRAnalyzer& operator=(const TBRAnalyzer&) = delete; + /// Returns the result of the whole analysis - const std::map getResult() { return TBRLocs; } + std::map getResult() { return TBRLocs; } /// Visitors - void Analyze(const clang::FunctionDecl* FD); void Visit(const clang::Stmt* stmt) { diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 2c06b87be..4c0bfff44 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -534,15 +534,11 @@ namespace clad { return false; } - std::vector GetInnermostReturnExpr(clang::Expr* E) { + std::vector GetInnermostReturnExpr(const clang::Expr* E) { struct Finder : public ConstStmtVisitor { std::vector m_return_exprs; - // Sema* m_Sema; - // ASTContext* m_Context; public: - Finder(/*Sema* S*/) /* : m_Sema(S), m_Context(S.getASTContext()) */ {} - std::vector Find(const clang::Expr* E) { Visit(E); return m_return_exprs; @@ -551,7 +547,7 @@ namespace clad { void VisitBinaryOperator(const clang::BinaryOperator* BO) { if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) { Visit(BO->getLHS()); - } else if (BO->isCommaOp()) { + } else if (BO->getOpcode() == clang::BO_Comma) { /**/ } else { assert("Unexpected binary operator!!"); @@ -575,10 +571,8 @@ namespace clad { m_return_exprs.push_back(const_cast(DRE)); } - void VisitExpr(const clang::Expr* E) { - if (auto PE = dyn_cast(E)) { - Visit(PE->getSubExpr()); - } + void VisitParenExpr(const clang::ParenExpr* PE) { + Visit(PE->getSubExpr()); } void VisitMemberExpr(const clang::MemberExpr* ME) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 918a06d22..00774ed0b 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -455,7 +455,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivativeAndOverload ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, const DiffRequest& request) { - TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context); + auto* analyzer = new TBRAnalyzer(&m_Context); analyzer->Analyze(FD); m_ToBeRecorded = analyzer->getResult(); delete analyzer; @@ -571,7 +571,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DifferentiateWithClad() { - TBRAnalyzer* analyzer = new TBRAnalyzer(&m_Context); + auto* analyzer = new TBRAnalyzer(&m_Context); analyzer->Analyze(m_Function); m_ToBeRecorded = analyzer->getResult(); delete analyzer; @@ -1897,10 +1897,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // call to gradient and call to original function. At this point, each arg // is either a simple expression or a reference to a temporary variable. // Therefore cloning it has constant complexity. - std::transform(std::begin(CallArgs), - std::end(CallArgs), - std::begin(CallArgs), - [this](Expr* E) { return Clone(E); }); // Recreate the original call expression. Expr* call = m_Sema .ActOnCallExpr(getCurrentScope(), @@ -1934,8 +1930,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, dfdx()); auto EStored = GlobalStoreAndRef(diff.getExpr()); if (EStored.getExpr() != diff.getExpr()) { - auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), - EStored.getExpr_dx()); + auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), + EStored.getExpr_dx()); if (isInsideLoop) addToCurrentBlock(EStored.getExpr(), direction::forward); addToCurrentBlock(assign, direction::reverse); @@ -1948,8 +1944,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, dfdx()); auto EStored = GlobalStoreAndRef(diff.getExpr()); if (EStored.getExpr() != diff.getExpr()) { - auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), - EStored.getExpr_dx()); + auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), + EStored.getExpr_dx()); if (isInsideLoop) addToCurrentBlock(EStored.getExpr(), direction::forward); addToCurrentBlock(assign, direction::reverse); @@ -2242,7 +2238,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto E : return_exprs) { Lstored = GlobalStoreAndRef(E); if (Lstored.getExpr() != E) { - auto assign = + auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, E, Lstored.getExpr_dx()); if (isInsideLoop) addToCurrentBlock(Lstored.getExpr(), direction::forward); @@ -2695,10 +2691,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto it = m_ToBeRecorded.find(B->getBeginLoc()); if (it == m_ToBeRecorded.end()) { return true; - } else { - return it->second; } - // return true; + return it->second; } // FIXME: Attach checkpointing. diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 80996c4db..366fb2ab4 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -8,11 +8,11 @@ void TBRAnalyzer::VarData::setIsRequired(bool isReq) { if (type == FUND_TYPE) { val.fundData = isReq; } else if (type == OBJ_TYPE) { - for (auto pair : val.objData) { + for (auto& pair : *val.objData) { pair.second->setIsRequired(isReq); } } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { + for (auto& pair : *val.arrData) { pair.second->setIsRequired(isReq); } } else if (type == REF_TYPE && val.refData) { @@ -24,22 +24,22 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { if (this->type == FUND_TYPE) { this->val.fundData = this->val.fundData || mergeData->val.fundData; } else if (this->type == OBJ_TYPE) { - for (auto pair : this->val.objData) { - pair.second->merge(mergeData->val.objData[pair.first]); + for (auto& pair : *this->val.objData) { + pair.second->merge((*mergeData->val.objData)[pair.first]); } } else if (this->type == ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. - for (auto pair : this->val.arrData) { - auto it = mergeData->val.arrData.find(pair.first); - if (it != mergeData->val.arrData.end()) { + for (auto& pair : *this->val.arrData) { + auto it = mergeData->val.arrData->find(pair.first); + if (it != mergeData->val.arrData->end()) { pair.second->merge(it->second); } } - for (auto pair : mergeData->val.arrData) { - auto it = this->val.arrData.find(pair.first); - if (it == mergeData->val.arrData.end()) { + for (auto& pair : *mergeData->val.arrData) { + auto it = this->val.arrData->find(pair.first); + if (it == mergeData->val.arrData->end()) { std::unordered_map refVars; - this->val.arrData[pair.first] = pair.second->copy(refVars); + (*this->val.arrData)[pair.first] = pair.second->copy(refVars); } } } else if (this->type == REF_TYPE && this->val.refData) { @@ -47,59 +47,72 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { } } +TBRAnalyzer::VarData* TBRAnalyzer::VarData::copy() { + std::unordered_map refVars; + VarData* res = copy(refVars); + res->restoreRefs(refVars); + return res; +} + TBRAnalyzer::VarData* TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { - VarData* res; - + auto* res = new VarData(); /// The child node of a reference node should be copied only once. Hence, /// we use refVars to match original referenced nodes to corresponding copies. if (isReferenced) { - auto it = refVars.find(this); - if (it != refVars.end()) { - return it->second; - } else { - res = new VarData(); - refVars[this] = res; - } - } else { - res = new VarData(); + refVars[this] = res; } - res->type = this->type; - if (this->type == FUND_TYPE) { res->val.fundData = this->val.fundData; } else if (this->type == OBJ_TYPE) { - for (auto pair : this->val.objData) - res->val.objData[pair.first] = pair.second->copy(refVars); + res->val.objData = new ObjMap(); + for (auto& pair : *this->val.objData) + (*res->val.objData)[pair.first] = pair.second->copy(refVars); } else if (this->type == ARR_TYPE) { - for (auto pair : this->val.arrData) { - res->val.arrData[pair.first] = pair.second->copy(refVars); + res->val.arrData = new ArrMap(); + for (auto& pair : *this->val.arrData) { + (*res->val.arrData)[pair.first] = pair.second->copy(refVars); } } else if (this->type == REF_TYPE && this->val.refData) { - res->val.refData = this->val.refData->copy(refVars); - res->val.refData->isReferenced = true; + res->val.refData = this->val.refData; } - return res; } -bool TBRAnalyzer::VarData::findReq() { +void TBRAnalyzer::VarData::restoreRefs( + std::unordered_map& refVars) { + if (this->type == OBJ_TYPE) { + for (auto& pair : *val.objData) + pair.second->restoreRefs(refVars); + } else if (this->type == ARR_TYPE) { + for (auto& pair : *this->val.arrData) { + pair.second->restoreRefs(refVars); + } + } else if (this->type == REF_TYPE && this->val.refData) { + this->val.refData = refVars[this->val.refData]; + } +} + +bool TBRAnalyzer::VarData::findReq() const { if (type == FUND_TYPE) { return val.fundData; } else if (type == OBJ_TYPE) { - for (auto pair : val.objData) { - if (pair.second->findReq()) + for (auto& pair : *val.objData) { + if (pair.second->findReq()) { return true; + } } } else if (type == ARR_TYPE) { - for (auto pair : val.arrData) { - if (pair.second->findReq()) + for (auto& pair : *val.arrData) { + if (pair.second->findReq()) { return true; + } } } else if (type == REF_TYPE && val.refData) { - if (val.refData->findReq()) + if (val.refData->findReq()) { return true; + } } return false; } @@ -113,23 +126,23 @@ void TBRAnalyzer::VarData::overlay( --i; IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i]; if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) { - val.objData[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); + (*val.objData)[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { auto idx = curIdxOrMember.val.index; - if (idx == llvm::APInt(32, -1, true)) { - for (auto pair : val.arrData) { + if (idx == llvm::APInt(2, -1, true)) { + for (auto& pair : *val.arrData) { pair.second->overlay(IdxAndMemberSequence, i); } } else { - val.arrData[idx]->overlay(IdxAndMemberSequence, i); + (*val.arrData)[idx]->overlay(IdxAndMemberSequence, i); } } } TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, bool addNonConstIdx) { - if (auto FD = dyn_cast(ME->getMemberDecl())) { - auto base = ME->getBase(); + if (const auto* FD = dyn_cast(ME->getMemberDecl())) { + const auto* base = ME->getBase(); VarData* baseData = getExprVarData(base); /// If the VarData is ref type just go to the VarData being referenced. if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) { @@ -141,8 +154,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, /// FUND_TYPE might be set by default earlier. if (baseData->type == VarData::VarDataType::FUND_TYPE) { baseData->type = VarData::VarDataType::OBJ_TYPE; - baseData->val.objData = - std::unordered_map(); + baseData->val.objData = new ObjMap(); } /// if non-const index was found and it is not supposed to be added just /// return the current VarData*. @@ -151,14 +163,14 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, auto& baseObjData = baseData->val.objData; /// Add the current field if it was not added previously - if (baseObjData.find(FD) == baseObjData.end()) { - VarData* FDData = new VarData(); - baseObjData[FD] = FDData; + if (baseObjData->find(FD) == baseObjData->end()) { + (*baseObjData)[FD] = new VarData(); + auto* FDData = (*baseObjData)[FD]; FDData->type = VarData::VarDataType::UNDEFINED; return FDData; } - return baseObjData[FD]; + return (*baseObjData)[FD]; } return nullptr; } @@ -166,17 +178,17 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, TBRAnalyzer::VarData* TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, bool addNonConstIdx) { - auto idxExpr = ASE->getIdx(); + const auto* idxExpr = ASE->getIdx(); llvm::APInt idx; - if (auto IL = dyn_cast(idxExpr)) { + if (const auto* IL = dyn_cast(idxExpr)) { idx = IL->getValue(); } else { nonConstIndexFound = true; /// Non-const indices are represented with -1. - idx = llvm::APInt(32, -1, true); + idx = llvm::APInt(2, -1, true); } - auto base = ASE->getBase()->IgnoreImpCasts(); + const auto* base = ASE->getBase()->IgnoreImpCasts(); VarData* baseData = getExprVarData(base); /// If the VarData is ref type just go to the VarData being referenced. if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) { @@ -188,8 +200,7 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, /// FUND_TYPE might be set by default earlier. if (baseData->type == VarData::VarDataType::FUND_TYPE) { baseData->type = VarData::VarDataType::ARR_TYPE; - baseData->val.arrData = - std::unordered_map(); + baseData->val.arrData = new ArrMap(); } /// if non-const index was found and it is not supposed to be added just @@ -198,16 +209,16 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, return baseData; auto& baseArrData = baseData->val.arrData; - auto itEnd = baseArrData.end(); + auto itEnd = baseArrData->end(); /// Add the current index if it was not added previously - if (baseArrData.find(idx) == itEnd) { - VarData* idxData = new VarData(); - baseArrData[idx] = idxData; + if (baseArrData->find(idx) == itEnd) { + (*baseArrData)[idx] = new VarData(); + auto* idxData = (*baseArrData)[idx]; /// Since -1 represents non-const indices, whenever we add a new index we /// have to copy the VarData of -1's element (if an element with undefined /// index was used this might be our current element). - auto it = baseArrData.find(llvm::APInt(32, -1, true)); + auto it = baseArrData->find(llvm::APInt(2, -1, true)); if (it != itEnd) { std::unordered_map dummy; idxData = it->second->copy(dummy); @@ -217,7 +228,7 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, return idxData; } - return baseArrData[idx]; + return (*baseArrData)[idx]; } TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, @@ -226,21 +237,28 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, /// x would be implicitly casted with the * operator). E = E->IgnoreImpCasts(); VarData* EData; - if (auto DRE = dyn_cast(E)) { - if (auto VD = dyn_cast(DRE->getDecl())) { - EData = reqStack.back()[VD]; + if (isa(E) || isa(E)) { + const VarDecl* VD = nullptr; + /// ``this`` does not have a declaration so it is represented with nullptr. + if (const auto* DRE = dyn_cast(E)) + VD = dyn_cast(DRE->getDecl()); + /// The index i is shifted since otherwise the last value would be i=-1 + /// and size_t can only take positive values. + for (size_t i = reqStack.size(); i > 0; --i) { + auto& branch = reqStack[i - 1].back(); + const auto it = branch.find(VD); + if (it != branch.end()) { + EData = it->second; + break; + } } } - if (auto ME = dyn_cast(E)) { + if (const auto* ME = dyn_cast(E)) { EData = getMemberVarData(ME, addNonConstIdx); } - if (auto ASE = dyn_cast(E)) { + if (const auto* ASE = dyn_cast(E)) { EData = getArrSubVarData(ASE, addNonConstIdx); } - /// 'this' pointer is represented as a nullptr. - if (auto TE = dyn_cast(E)) { - EData = reqStack.back()[nullptr]; - } /// If the type of this VarData was not defined previously set it to /// FUND_TYPE. @@ -253,28 +271,25 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, return EData; } -void TBRAnalyzer::addField( - std::unordered_map& objData, - const FieldDecl* FD) { - auto varType = FD->getType(); - VarData* data = new VarData(); - objData[FD] = data; +void TBRAnalyzer::addField(ObjMap* objData, const FieldDecl* FD) { + const auto varType = FD->getType(); + (*objData)[FD] = new VarData(); + VarData* data = (*objData)[FD]; if (varType->isReferenceType()) { data->type = VarData::VarDataType::REF_TYPE; data->val.refData = nullptr; } else if (utils::isArrayOrPointerType(varType)) { data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = - std::unordered_map(); + data->val.arrData = new ArrMap(); } else if (varType->isBuiltinType()) { data->type = VarData::VarDataType::FUND_TYPE; data->val.fundData = false; } else if (varType->isRecordType()) { data->type = VarData::VarDataType::OBJ_TYPE; - auto recordDecl = varType->getAs()->getDecl(); + const auto* recordDecl = varType->getAs()->getDecl(); auto& newObjData = data->val.objData; - for (auto field : recordDecl->fields()) { + for (const auto* field : recordDecl->fields()) { addField(newObjData, field); } } @@ -288,15 +303,15 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { /// Unwrap the given expression to a vector of indices and fields. while (cond) { E = E->IgnoreImplicit(); - if (auto ASE = dyn_cast(E)) { - if (auto IL = dyn_cast(ASE->getIdx())) { + if (const auto* ASE = dyn_cast(E)) { + if (const auto* IL = dyn_cast(ASE->getIdx())) { IdxAndMemberSequence.push_back(IdxOrMember(IL->getValue())); } else { - IdxAndMemberSequence.push_back(IdxOrMember(llvm::APInt(32, -1, true))); + IdxAndMemberSequence.push_back(IdxOrMember(llvm::APInt(2, -1, true))); } E = ASE->getBase(); - } else if (auto ME = dyn_cast(E)) { - if (auto FD = dyn_cast(ME->getMemberDecl())) { + } else if (const auto* ME = dyn_cast(E)) { + if (const auto* FD = dyn_cast(ME->getMemberDecl())) { IdxAndMemberSequence.push_back(IdxOrMember(FD)); } E = ME->getBase(); @@ -305,16 +320,21 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { } else return; } + /// Overlay on all the VarData's recursively. - if (auto VD = dyn_cast(innermostDRE->getDecl())) - reqStack.back()[VD]->overlay(IdxAndMemberSequence, - IdxAndMemberSequence.size()); + if (const auto* VD = dyn_cast(innermostDRE->getDecl())) { + getCurBranch()[VD]->overlay(IdxAndMemberSequence, + IdxAndMemberSequence.size()); + } } void TBRAnalyzer::addVar(const clang::VarDecl* VD) { - // FIXME: this marks the SourceLocation of DeclStmt which doesn't work for - // declarations with multiple VarDecls. - auto& curBranch = reqStack.back(); + /// If a declaration is passed a second time (meaning it is inside a loop), + /// treat it as an assigment operation. + /// FIXME: this marks the SourceLocation of DeclStmt which doesn't work for + /// declarations with multiple VarDecls. + size_t len = reqStack.size(); + auto& curBranch = reqStack[len - 1].back(); if (curBranch.find(VD) != curBranch.end()) { auto& VDData = curBranch[VD]; if (VDData->type == VarData::VarDataType::FUND_TYPE) { @@ -322,46 +342,54 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { (deleteCurBranch ? false : VDData->findReq()); } } + /// The index here is shifted by one since otherwise the loop would end with + /// i=-1 and size_t is positive only. + for (size_t i = len - 1; i > 0; --i) { + auto& branch = reqStack[i - 1].back(); + if (branch.find(VD) != branch.end()) { + curBranch[VD] = branch[VD]->copy(); + break; + } + } if (!localVarsStack.empty()) { localVarsStack.back().push_back(VD); } - auto varType = VD->getType(); - VarData* data = new VarData(); - curBranch[VD] = data; + const auto varType = VD->getType(); + curBranch[VD] = new VarData(); + VarData* data = curBranch[VD]; if (varType->isReferenceType()) { data->type = VarData::VarDataType::REF_TYPE; data->val.refData = nullptr; } else if (utils::isArrayOrPointerType(varType)) { - if (auto pointerType = llvm::dyn_cast(varType)) { + if (const auto pointerType = llvm::dyn_cast(varType)) { /// FIXME: If the pointer points to an object we represent it with a /// OBJ_TYPE VarData*. - auto pointeeType = pointerType->getPointeeType().getTypePtrOrNull(); + const auto pointeeType = pointerType->getPointeeType().getTypePtrOrNull(); if (pointeeType && pointeeType->isRecordType()) { data->type = VarData::VarDataType::OBJ_TYPE; - auto recordDecl = pointeeType->getAs()->getDecl(); + const auto* recordDecl = pointeeType->getAs()->getDecl(); auto& objData = data->val.objData; - objData = std::unordered_map(); - for (auto field : recordDecl->fields()) { + objData = new ObjMap(); + for (const auto* field : recordDecl->fields()) { addField(objData, field); } return; } } data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = - std::unordered_map(); + data->val.arrData = new ArrMap(); } else if (varType->isBuiltinType()) { data->type = VarData::VarDataType::FUND_TYPE; data->val.fundData = false; } else if (varType->isRecordType()) { data->type = VarData::VarDataType::OBJ_TYPE; - auto recordDecl = varType->getAs()->getDecl(); + const auto* recordDecl = varType->getAs()->getDecl(); auto& objData = data->val.objData; - objData = std::unordered_map(); - for (auto field : recordDecl->fields()) { + objData = new ObjMap(); + for (const auto* field : recordDecl->fields()) { addField(objData, field); } } @@ -385,46 +413,87 @@ void TBRAnalyzer::markLocation(const clang::Expr* E) { TBRLocs[E->getBeginLoc()] = !deleteCurBranch; } -void TBRAnalyzer::addBranch() { - VarsData& curBranch = reqStack.back(); - VarsData newBranch; - std::unordered_map refVars; - for (auto pair : curBranch) { - newBranch[pair.first] = pair.second->copy(refVars); +void TBRAnalyzer::mergeLayer() { + size_t len = reqStack.size(); + auto& removedLayer = reqStack[len - 1]; + auto& curBranch = reqStack[len - 2].back(); + + for (auto& removedBranch : removedLayer) { + for (auto& pair : curBranch) { + auto it = removedBranch.find(pair.first); + if (it != removedBranch.end()) { + pair.second->merge(it->second); + } + } + for (auto& pair : removedBranch) { + auto it = curBranch.find(pair.first); + if (it == curBranch.end()) { + delete curBranch[pair.first]; + curBranch[pair.first] = pair.second; + } else { + delete pair.second; + } + } } - reqStack.push_back(newBranch); + reqStack.pop_back(); } -void TBRAnalyzer::mergeAndDelete(bool keepNewVars) { - auto removedBranch = reqStack.back(); - reqStack.pop_back(); - auto& curBranch = reqStack.back(); +void TBRAnalyzer::mergeLayerOnTop() { + size_t len = reqStack.size(); + auto& removedLayer = reqStack[len - 1]; - if (keepNewVars) { - for (auto& pair : curBranch) { - removedBranch[pair.first]->merge(pair.second); - delete pair.second; - pair.second = removedBranch[pair.first]; + if (removedLayer.empty()) { + reqStack.pop_back(); + return; + } + + auto& curBranch = reqStack[len - 2].back(); + + /// First, we merge every branch on the layer with the first one there. + auto branchIter = removedLayer.begin(); + auto branchIterEnd = removedLayer.end(); + auto& firstBranch = *branchIter; + while ((++branchIter) != branchIterEnd) { + for (auto& pair : firstBranch) { + auto elemIter = branchIter->find(pair.first); + if (elemIter != branchIter->end()) { + pair.second->merge(elemIter->second); + } } - } else { - for (auto pair : curBranch) { - pair.second->merge(removedBranch[pair.first]); - delete removedBranch[pair.first]; + for (auto& pair : *branchIter) { + auto elemIter = firstBranch.find(pair.first); + if (elemIter == firstBranch.end()) { + delete firstBranch[pair.first]; + firstBranch[pair.first] = pair.second; + } else { + delete pair.second; + } } } -} -void TBRAnalyzer::swapLastPairOfBranches() { - size_t s = reqStack.size(); - std::swap(reqStack[s - 1], reqStack[s - 2]); -} + /// Second, we place it on top of the branch on the previous layer's last + /// branch. + for (auto& pair : firstBranch) { + delete curBranch[pair.first]; + curBranch[pair.first] = pair.second; + } -void TBRAnalyzer::mergeCurBranchTo(size_t targetBranchNum) { - auto& targetBranch = reqStack[targetBranchNum]; - auto& curBranch = reqStack.back(); + reqStack.pop_back(); +} +void TBRAnalyzer::mergeBranchTo(size_t sourceBranchNum, + VarsData& targetBranch) { for (auto& pair : targetBranch) { - pair.second->merge(curBranch[pair.first]); + /// Index i is shifted by one since otherwise its last value could be -1 + /// and size_t is only positive. + for (size_t i = sourceBranchNum + 1; i > 0; --i) { + auto& sourceBranch = reqStack[i - 1].back(); + auto it = sourceBranch.find(pair.first); + if (it != sourceBranch.end()) { + pair.second->merge(it->second); + break; + } + } } } @@ -447,14 +516,15 @@ void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { void TBRAnalyzer::Analyze(const FunctionDecl* FD) { /// If we are analysing a method add a VarData for 'this' pointer (it is /// represented with nullptr). + if (isa(FD)) { - VarData* data = new VarData(); - reqStack.back()[nullptr] = data; - data->type = VarData::VarDataType::OBJ_TYPE; + auto& thisData = getCurBranch()[nullptr]; + thisData = new VarData(); + thisData->type = VarData::VarDataType::OBJ_TYPE; auto recordDecl = dyn_cast(FD->getParent()); - auto& objData = data->val.objData; - objData = std::unordered_map(); - for (auto field : recordDecl->fields()) { + auto& objData = thisData->val.objData; + objData = new ObjMap(); + for (const auto* field : recordDecl->fields()) { addField(objData, field); } } @@ -471,8 +541,8 @@ void TBRAnalyzer::VisitCompoundStmt(const CompoundStmt* CS) { } void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { - if (auto VD = dyn_cast(DRE->getDecl())) { - auto& curBranch = reqStack.back(); + if (const auto* VD = dyn_cast(DRE->getDecl())) { + auto& curBranch = getCurBranch(); // FIXME: this is only necessary to ensure global variables are added. // It doesn't make any sense to first add variables when visiting DeclStmt // and then checking if they were added while visiting DeclRefExpr. @@ -480,7 +550,7 @@ void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { addVar(VD); } - if (auto E = dyn_cast(DRE)) { + if (const auto* E = dyn_cast(DRE)) { setIsRequired(E); } } @@ -511,20 +581,21 @@ void TBRAnalyzer::VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE) { } void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { - for (auto D : DS->decls()) { - if (auto VD = dyn_cast(D)) { + for (const auto* D : DS->decls()) { + if (const auto* VD = dyn_cast(D)) { addVar(VD); - if (clang::Expr* init = VD->getInit()) { + if (const clang::Expr* init = VD->getInit()) { setMode(Mode::markingMode); Visit(init); resetMode(); - auto VDExpr = reqStack.back()[VD]; + auto& VDExpr = getCurBranch()[VD]; /// if the declared variable is ref type attach its VarData* to the /// VarData* of the RHS variable. if (VDExpr->type == VarData::VarDataType::REF_TYPE) { - auto RHSExpr = getExprVarData(utils::GetInnermostReturnExpr(init)[0]); + auto* RHSExpr = + getExprVarData(utils::GetInnermostReturnExpr(init)[0]); VDExpr->val.refData = RHSExpr; - RHSExpr->isReferenced = VDExpr; + RHSExpr->isReferenced = true; } } } @@ -537,25 +608,23 @@ void TBRAnalyzer::VisitConditionalOperator( Visit(CO->getCond()); resetMode(); + addLayer(); addBranch(); Visit(CO->getTrueExpr()); - swapLastPairOfBranches(); + addBranch(); Visit(CO->getFalseExpr()); - mergeAndDelete(); + mergeLayerOnTop(); } void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { - auto opCode = BinOp->getOpcode(); - auto L = BinOp->getLHS(); - auto R = BinOp->getRHS(); + const auto opCode = BinOp->getOpcode(); + const auto* L = BinOp->getLHS(); + const auto* R = BinOp->getRHS(); /// Addition is not able to create any differential influence by itself so /// markingMode should be left as it is. Similarly, addition does not affect /// linearity so nonLinearMode shouldn't be changed as well. The same applies /// to subtraction. - if (opCode == BO_Add) { - Visit(L); - Visit(R); - } else if (opCode == BO_Sub) { + if (opCode == BO_Add || opCode == BO_Sub) { Visit(L); Visit(R); } else if (opCode == BO_Mul) { @@ -616,8 +685,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { Visit(R); resetMode(); } - auto return_exprs = utils::GetInnermostReturnExpr(L); - for (auto innerExpr : return_exprs) { + const auto return_exprs = utils::GetInnermostReturnExpr(L); + for (const auto* innerExpr : return_exprs) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be used changed. markLocation(innerExpr); @@ -640,16 +709,16 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { } void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { - auto opCode = UnOp->getOpcode(); - Expr* E = UnOp->getSubExpr(); + const auto opCode = UnOp->getOpcode(); + const Expr* E = UnOp->getSubExpr(); Visit(E); if (opCode == UO_PostInc || opCode == UO_PostDec || opCode == UO_PreInc || opCode == UO_PreDec) { // FIXME: this doesn't support all the possible references /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be used in this operation. - auto innerExprs = utils::GetInnermostReturnExpr(E); - for (auto innerExpr : innerExprs) { + const auto innerExprs = utils::GetInnermostReturnExpr(E); + for (const auto* innerExpr : innerExprs) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be changed. markLocation(innerExpr); @@ -662,9 +731,9 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { } void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { - auto cond = If->getCond(); - auto condVarDecl = If->getConditionVariable(); - auto condInit = If->getInit(); + const auto* cond = If->getCond(); + const auto* condVarDecl = If->getConditionVariable(); + const auto* condInit = If->getInit(); /// We have to separated analyse then-block and else-block and then merge /// them together. First, we make a copy of the current branch and analyse @@ -678,14 +747,14 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { /// ... - - /// ... - - addBranch(); + const auto* thenBranch = If->getThen(); + const auto* elseBranch = If->getElse(); + + localVarsStack.emplace_back(); + addLayer(); - bool thenBranchNotDeleted = true; - bool elseBranchNotDeleted = true; - auto thenBranch = If->getThen(); - auto elseBranch = If->getElse(); if (thenBranch) { - localVarsStack.push_back(std::vector()); + addBranch(); Visit(cond); if (condVarDecl) addVar(condVarDecl); @@ -698,20 +767,13 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { if (deleteCurBranch) { /// This section is performed if this branch had break/continue/return /// and, therefore, shouldn't be merged. - reqStack.pop_back(); + deleteBranch(); deleteCurBranch = false; - thenBranchNotDeleted = false; - } else { - /// We have to remove local variables from then-branch to later merge the - /// else-branch into it. - removeLocalVars(); - localVarsStack.pop_back(); } } if (elseBranch) { - if (thenBranchNotDeleted) - swapLastPairOfBranches(); + addBranch(); Visit(cond); if (condVarDecl) addVar(condVarDecl); @@ -721,23 +783,23 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { resetMode(); } Visit(elseBranch); - if (deleteCurBranch && thenBranchNotDeleted) { + if (deleteCurBranch) { /// This section is performed if this branch had break/continue/return /// and, therefore, shouldn't be merged. - reqStack.pop_back(); + deleteBranch(); deleteCurBranch = false; - elseBranchNotDeleted = false; } } - if (thenBranchNotDeleted && elseBranchNotDeleted) - mergeAndDelete(); + mergeLayerOnTop(); + removeLocalVars(); + localVarsStack.pop_back(); } void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { - auto body = WS->getBody(); - auto cond = WS->getCond(); - size_t backupILB = innermostLoopBranch; + const auto* body = WS->getBody(); + const auto* cond = WS->getCond(); + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; /// Let's assume we have a section of code structured like this @@ -766,45 +828,51 @@ void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { /// ... - Visit(cond); + addLayer(); + addBranch(); + addBranch(); + + addLayer(); addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 1; firstLoopPass = true; if (body) Visit(body); if (deleteCurBranch) { - reqStack.pop_back(); + deleteBranch(); deleteCurBranch = backupDCB; } else { Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); /// Second pass - --innermostLoopBranch; firstLoopPass = false; if (body) Visit(body); if (deleteCurBranch) - reqStack.pop_back(); + deleteBranch(); else { Visit(cond); - mergeAndDelete(); } + mergeLayer(); - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { - auto body = FS->getBody(); - auto condVar = FS->getConditionVariable(); - auto init = FS->getInit(); - auto cond = FS->getCond(); - auto incr = FS->getInc(); - size_t backupILB = innermostLoopBranch; + const auto* body = FS->getBody(); + const auto* condVar = FS->getConditionVariable(); + auto* init = FS->getInit(); + const auto* cond = FS->getCond(); + const auto* incr = FS->getInc(); + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; /// The logic here is virtually the same as with while-loop. Take a look at @@ -816,50 +884,55 @@ void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { } if (cond) Visit(cond); + addLayer(); + addBranch(); addBranch(); if (condVar) addVar(condVar); + addLayer(); + addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 1; firstLoopPass = true; if (body) Visit(body); if (deleteCurBranch) { - reqStack.pop_back(); + deleteBranch(); deleteCurBranch = backupDCB; } else { if (incr) Visit(incr); if (cond) Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); /// Second pass - --innermostLoopBranch; firstLoopPass = false; if (body) Visit(body); if (incr) Visit(incr); if (deleteCurBranch) - reqStack.pop_back(); + deleteBranch(); else { if (cond) Visit(cond); - mergeAndDelete(); } + mergeLayer(); - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { - auto body = DS->getBody(); - auto cond = DS->getCond(); - size_t backupILB = innermostLoopBranch; + const auto* body = DS->getBody(); + const auto* cond = DS->getCond(); + size_t backupILB = innermostLoopLayer; bool backupFLP = firstLoopPass; bool backupDCB = deleteCurBranch; @@ -870,10 +943,14 @@ void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { /// having two loop branches is necessary for handling continue statements /// so we can't just remove one of them. + addLayer(); + addBranch(); + addBranch(); + addLayer(); addBranch(); addBranch(); /// First pass - innermostLoopBranch = reqStack.size() - 2; + innermostLoopLayer = reqStack.size() - 2; firstLoopPass = true; if (body) Visit(body); @@ -882,21 +959,58 @@ void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { deleteCurBranch = backupDCB; } else { Visit(cond); - mergeAndDelete(/*keepNewVars=*/true); + mergeLayer(); } /// Second pass - --innermostLoopBranch; + --innermostLoopLayer; firstLoopPass = false; if (body) Visit(body); Visit(cond); if (deleteCurBranch) { reqStack.pop_back(); - mergeAndDelete(); + mergeLayer(); } - innermostLoopBranch = backupILB; + innermostLoopLayer = backupILB; + firstLoopPass = backupFLP; + deleteCurBranch = backupDCB; + + addLayer(); + addBranch(); + addBranch(); + + addLayer(); + addBranch(); + addBranch(); + /// First pass + innermostLoopLayer = reqStack.size() - 1; + firstLoopPass = true; + if (body) + Visit(body); + if (deleteCurBranch) { + deleteBranch(); + deleteCurBranch = backupDCB; + } else { + Visit(cond); + } + --innermostLoopLayer; + mergeLayer(); + mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); + + /// Second pass + firstLoopPass = false; + if (body) + Visit(body); + if (deleteCurBranch) + deleteBranch(); + else { + Visit(cond); + } + mergeLayer(); + + innermostLoopLayer = backupILB; firstLoopPass = backupFLP; deleteCurBranch = backupDCB; } @@ -915,11 +1029,18 @@ void TBRAnalyzer::VisitContinueStmt(const clang::ContinueStmt* CS) { /// followed by another iteration. We have to either add an additional branch /// or find a better solution. (However, this bug will matter only in really /// rare cases) - mergeCurBranchTo(innermostLoopBranch); + + auto& targetLayer1 = reqStack[innermostLoopLayer]; + auto& targetBranch1 = targetLayer1[targetLayer1.size() - 2]; + size_t sourceBranchNum = reqStack.size() - 1; + mergeBranchTo(sourceBranchNum, targetBranch1); /// After the continue statement, this branch cannot be followed by any other /// code so we can delete it. - if (firstLoopPass) - mergeCurBranchTo(innermostLoopBranch - 1); + if (firstLoopPass) { + auto& targetLayer2 = reqStack[innermostLoopLayer - 1]; + auto& targetBranch2 = targetLayer2[targetLayer2.size() - 2]; + mergeBranchTo(sourceBranchNum, targetBranch2); + } deleteCurBranch = true; } @@ -928,8 +1049,11 @@ void TBRAnalyzer::VisitBreakStmt(const clang::BreakStmt* BS) { /// ... - - /// And so this break could be the end of this loop. So we have to merge /// the current branch into the first branch on the diagram. - if (!firstLoopPass) - mergeCurBranchTo(innermostLoopBranch); + if (!firstLoopPass) { + auto& targetLayer = reqStack[innermostLoopLayer]; + auto& targetBranch = targetLayer[targetLayer.size() - 2]; + mergeBranchTo(reqStack.size() - 1, targetBranch); + } /// After the break statement, this branch cannot be followed by any other /// code so we can delete it. deleteCurBranch = true; @@ -939,17 +1063,17 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { /// FIXME: Currently TBR analysis just stops here and assumes that all the /// variables passed by value/reference are used/used and changed. Analysis /// could proceed to the function to analyse data flow inside it. - auto FD = CE->getDirectCallee(); + auto* FD = CE->getDirectCallee(); setMode(Mode::markingMode | Mode::nonLinearMode); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { - clang::Expr* arg = const_cast(CE->getArg(i)); + const clang::Expr* arg = CE->getArg(i); bool passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); setMode(Mode::markingMode | Mode::nonLinearMode); Visit(arg); resetMode(); - auto B = arg->IgnoreParenImpCasts(); + const auto* B = arg->IgnoreParenImpCasts(); // FIXME: this supports only DeclRefExpr - auto innerExpr = utils::GetInnermostReturnExpr(arg); + const auto innerExpr = utils::GetInnermostReturnExpr(arg); if (passByRef) { /// Mark SourceLocation as required for ref-type arguments. if (isa(B) || isa(B)) { @@ -970,15 +1094,15 @@ void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) { /// variables passed by value/reference are used/used and changed. Analysis /// could proceed to the constructor to analyse data flow inside it. /// FIXME: add support for default values - auto FD = CE->getConstructor(); + auto* FD = CE->getConstructor(); setMode(Mode::markingMode | Mode::nonLinearMode); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { - auto arg = CE->getArg(i); + const auto* arg = CE->getArg(i); bool passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); setMode(Mode::markingMode | Mode::nonLinearMode); Visit(arg); resetMode(); - auto B = arg->IgnoreParenImpCasts(); + const auto* B = arg->IgnoreParenImpCasts(); // FIXME: this supports only DeclRefExpr if (passByRef) { /// Mark SourceLocation as required for ref-type arguments. @@ -1001,6 +1125,9 @@ void TBRAnalyzer::VisitMemberExpr(const clang::MemberExpr* ME) { void TBRAnalyzer::VisitArraySubscriptExpr( const clang::ArraySubscriptExpr* ASE) { + setMode(0); + Visit(ASE->getBase()); + resetMode(); setIsRequired(dyn_cast(ASE)); setMode(Mode::markingMode | Mode::nonLinearMode); Visit(ASE->getIdx()); @@ -1009,15 +1136,15 @@ void TBRAnalyzer::VisitArraySubscriptExpr( void TBRAnalyzer::VisitInitListExpr(const clang::InitListExpr* ILE) { setMode(0); - for (auto init : ILE->inits()) { + for (auto* init : ILE->inits()) { Visit(init); } resetMode(); } void TBRAnalyzer::removeLocalVars() { - auto& curBranch = reqStack.back(); - for (auto VD : localVarsStack.back()) + auto& curBranch = getCurBranch(); + for (const auto* VD : localVarsStack.back()) curBranch.erase(VD); } diff --git a/test/Enzyme/DifferentCladEnzymeDerivatives.C b/test/Enzyme/DifferentCladEnzymeDerivatives.C index c3cd617dd..f45e8e1d3 100644 --- a/test/Enzyme/DifferentCladEnzymeDerivatives.C +++ b/test/Enzyme/DifferentCladEnzymeDerivatives.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out | FileCheck %s +// RUN: %cladclang %s -I%S/../../include -oDifferentCladEnzymeDerivatives.out // RUN: ./DifferentCladEnzymeDerivatives.out // CHECK-NOT: {{.*error|warning|note:.*}} // REQUIRES: Enzyme @@ -32,6 +32,6 @@ double foo(double x, double y){ // CHECK-NEXT: } int main(){ - auto grad = clad::gradient(foo); + auto grad = clad::gradient(foo); auto gradEnzyme = clad::gradient(foo); -} \ No newline at end of file +} diff --git a/test/Enzyme/LoopsReverseModeComparisonWithClad.C b/test/Enzyme/LoopsReverseModeComparisonWithClad.C index 73a3e94ec..2a8e699b7 100644 --- a/test/Enzyme/LoopsReverseModeComparisonWithClad.C +++ b/test/Enzyme/LoopsReverseModeComparisonWithClad.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oEnzymeLoops.out 2>&1 | FileCheck %s +// RUN: %cladclang %s -I%S/../../include -oEnzymeLoops.out // RUN: ./EnzymeLoops.out | FileCheck -check-prefix=CHECK-EXEC %s // REQUIRES: Enzyme // CHECK-NOT: {{.*error|warning|note:.*}} @@ -92,7 +92,7 @@ double f6 (double i, double j) { double fn7(double i, double j) { double a = 0; int counter = 3; - while (counter--) + while (counter--) a += i*i + j; return a; } @@ -229,7 +229,7 @@ double fn13(double i, double j) { double fn14(double i, double j) { int choice = 5; double res = 0; - for (int counter=0; counter