diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 5da49f13f..ef29b7246 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -8,21 +8,30 @@ #include "clad/Differentiator/ParseDiffArgsTypes.h" namespace clang { - class ASTContext; - class CallExpr; - class CompilerInstance; - class DeclGroupRef; - class Expr; - class FunctionDecl; - class ParmVarDecl; - class Sema; - class Type; +class CallExpr; +class CompilerInstance; +class DeclGroupRef; +class Expr; +class FunctionDecl; +class ParmVarDecl; +class Sema; +class Type; } // namespace clang namespace clad { /// A struct containing information about request to differentiate a function. struct DiffRequest { +private: + /// Based on To-Be-Recorded analysis performed before differentiation, tells + /// UsefulToStoreGlobal whether a variable with a given SourceLocation has to + /// be stored before being changed or not. + mutable struct TbrRunInfo { + std::set ToBeRecorded; + bool HasAnalysisRun = false; + } m_TbrRunInfo; + +public: /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -118,6 +127,8 @@ struct DiffRequest { res += "__TBR"; return res; } + + bool shouldBeRecorded(clang::Expr* E) const; }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index b5e7e1ffd..f74e92c4a 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -1,5 +1,7 @@ #include "clad/Differentiator/DiffPlanner.h" +#include "TBRAnalyzer.h" + #include "clang/AST/ASTContext.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/SourceManager.h" @@ -559,6 +561,29 @@ namespace clad { return; } + bool DiffRequest::shouldBeRecorded(Expr* E) const { + assert(EnableTBRAnalysis && "TBR not enabled!"); + + if (!isa(E) && !isa(E) && + !isa(E)) + return true; + + // FIXME: currently, we allow all pointer operations to be stored. + // This is not correct, but we need to implement a more advanced analysis + // to determine which pointer operations are useful to store. + if (E->getType()->isPointerType()) + return true; + + if (!m_TbrRunInfo.HasAnalysisRun) { + TBRAnalyzer analyzer(Function->getASTContext(), + m_TbrRunInfo.ToBeRecorded); + analyzer.Analyze(Function); + m_TbrRunInfo.HasAnalysisRun = true; + } + auto found = m_TbrRunInfo.ToBeRecorded.find(E->getBeginLoc()); + return found != m_TbrRunInfo.ToBeRecorded.end(); + } + bool DiffCollector::VisitCallExpr(CallExpr* E) { // Check if we should look into this. // FIXME: Generated code does not usually have valid source locations. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ad2a85df2..2f0638df6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -475,12 +475,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivativeAndOverload ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, const DiffRequest& request) { - if (request.EnableTBRAnalysis) { - TBRAnalyzer analyzer(m_Context); - analyzer.Analyze(FD); - m_ToBeRecorded = analyzer.getResult(); - } - // FIXME: Duplication of external source here is a workaround // for the two 'Derive's being different functions. if (m_ExternalSource) @@ -603,12 +597,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } void ReverseModeVisitor::DifferentiateWithClad() { - if (m_DiffReq.EnableTBRAnalysis) { - TBRAnalyzer analyzer(m_Context); - analyzer.Analyze(m_DiffReq.Function); - m_ToBeRecorded = analyzer.getResult(); - } - llvm::ArrayRef paramsRef = m_Derivative->parameters(); // create derived variables for parameters which are not part of @@ -3065,27 +3053,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return UsefulToStoreGlobal(UO->getSubExpr()); return true; } - // We lack context to decide if this is useful to store or not. In the - // current system that should have been decided by the parent expression. - // FIXME: Here will be the entry point of the advanced activity analysis. - if (isa(B) || isa(B) || - isa(B)) { - // If TBR analysis is off, assume E is useful to store. - if (!m_DiffReq.EnableTBRAnalysis) - return true; - // FIXME: currently, we allow all pointer operations to be stored. - // This is not correct, but we need to implement a more advanced analysis - // to determine which pointer operations are useful to store. - if (E->getType()->isPointerType()) - return true; - auto found = m_ToBeRecorded.find(B->getBeginLoc()); - return found != m_ToBeRecorded.end(); - } // FIXME: Attach checkpointing. if (isa(B)) return false; + // FIXME: Here will be the entry point of the advanced activity analysis. + + // Check if the expression was marked as to-be recorded by an analysis. + if (m_DiffReq.EnableTBRAnalysis) + return m_DiffReq.shouldBeRecorded(B); + + // Assume E is useful to store. return true; } diff --git a/lib/Differentiator/TBRAnalyzer.h b/lib/Differentiator/TBRAnalyzer.h index ce13ce9ba..0603e0621 100644 --- a/lib/Differentiator/TBRAnalyzer.h +++ b/lib/Differentiator/TBRAnalyzer.h @@ -224,7 +224,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor { enum Mode { kMarkingMode = 1, kNonLinearMode = 2 }; /// Tells if the variable at a given location is required to store. Basically, /// is the result of analysis. - std::set m_TBRLocs; + std::set& m_TBRLocs; /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). @@ -287,7 +287,8 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor { public: /// Constructor - TBRAnalyzer(ASTContext& Context) : m_Context(Context) { + TBRAnalyzer(ASTContext& Context, std::set& Locs) + : m_TBRLocs(Locs), m_Context(Context) { m_ModeStack.push_back(0); } @@ -300,9 +301,6 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor { TBRAnalyzer(const TBRAnalyzer&&) = delete; TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete; - /// Returns the result of the whole analysis - std::set getResult() { return m_TBRLocs; } - /// Visitors void Analyze(const clang::FunctionDecl* FD);