Skip to content

Commit

Permalink
Move the TBR analysis in DiffRequest and provide an interface to quer…
Browse files Browse the repository at this point in the history
…y it.

The idea of this change is to separate the analysis steps from the building step.
That would help us detect unsupported cases and provide useful diagnostics
without crashing. The centralization will help with enabling the future activity
analysis in other modes, too.

Partially addresses #721.
  • Loading branch information
vgvassilev committed Jul 18, 2024
1 parent 3daeb51 commit 8fc2d5d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 42 deletions.
29 changes: 20 additions & 9 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,30 @@
#include "clad/Differentiator/ParseDiffArgsTypes.h"

namespace clang {
class ASTContext;
class CallExpr;
class CompilerInstance;
class DeclGroupRef;
class Expr;
class FunctionDecl;
class ParmVarDecl;
class Sema;
class Type;
class CallExpr;
class CompilerInstance;
class DeclGroupRef;
class Expr;
class FunctionDecl;
class ParmVarDecl;
class Sema;
class Type;
} // namespace clang

namespace clad {

/// A struct containing information about request to differentiate a function.
struct DiffRequest {
private:
/// Based on To-Be-Recorded analysis performed before differentiation, tells
/// UsefulToStoreGlobal whether a variable with a given SourceLocation has to
/// be stored before being changed or not.
mutable struct TbrRunInfo {
std::set<clang::SourceLocation> ToBeRecorded;
bool HasAnalysisRun = false;
} m_TbrRunInfo;

public:
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
Expand Down Expand Up @@ -118,6 +127,8 @@ struct DiffRequest {
res += "__TBR";
return res;
}

bool shouldBeRecorded(clang::Expr* E) const;
};

using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
25 changes: 25 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "clad/Differentiator/DiffPlanner.h"

#include "TBRAnalyzer.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/SourceManager.h"
Expand Down Expand Up @@ -559,6 +561,29 @@ namespace clad {
return;
}

bool DiffRequest::shouldBeRecorded(Expr* E) const {
assert(EnableTBRAnalysis && "TBR not enabled!");

if (!isa<DeclRefExpr>(E) && !isa<ArraySubscriptExpr>(E) &&
!isa<MemberExpr>(E))
return true;

// FIXME: currently, we allow all pointer operations to be stored.
// This is not correct, but we need to implement a more advanced analysis
// to determine which pointer operations are useful to store.
if (E->getType()->isPointerType())
return true;

if (!m_TbrRunInfo.HasAnalysisRun) {
TBRAnalyzer analyzer(Function->getASTContext(),
m_TbrRunInfo.ToBeRecorded);
analyzer.Analyze(Function);
m_TbrRunInfo.HasAnalysisRun = true;
}
auto found = m_TbrRunInfo.ToBeRecorded.find(E->getBeginLoc());
return found != m_TbrRunInfo.ToBeRecorded.end();
}

bool DiffCollector::VisitCallExpr(CallExpr* E) {
// Check if we should look into this.
// FIXME: Generated code does not usually have valid source locations.
Expand Down
35 changes: 7 additions & 28 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
if (request.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(FD);
m_ToBeRecorded = analyzer.getResult();
}

// FIXME: Duplication of external source here is a workaround
// for the two 'Derive's being different functions.
if (m_ExternalSource)
Expand Down Expand Up @@ -603,12 +597,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DifferentiateWithClad() {
if (m_DiffReq.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(m_DiffReq.Function);
m_ToBeRecorded = analyzer.getResult();
}

llvm::ArrayRef<ParmVarDecl*> paramsRef = m_Derivative->parameters();

// create derived variables for parameters which are not part of
Expand Down Expand Up @@ -3065,27 +3053,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return UsefulToStoreGlobal(UO->getSubExpr());
return true;
}
// We lack context to decide if this is useful to store or not. In the
// current system that should have been decided by the parent expression.
// FIXME: Here will be the entry point of the advanced activity analysis.
if (isa<DeclRefExpr>(B) || isa<ArraySubscriptExpr>(B) ||
isa<MemberExpr>(B)) {
// If TBR analysis is off, assume E is useful to store.
if (!m_DiffReq.EnableTBRAnalysis)
return true;
// FIXME: currently, we allow all pointer operations to be stored.
// This is not correct, but we need to implement a more advanced analysis
// to determine which pointer operations are useful to store.
if (E->getType()->isPointerType())
return true;
auto found = m_ToBeRecorded.find(B->getBeginLoc());
return found != m_ToBeRecorded.end();
}

// FIXME: Attach checkpointing.
if (isa<CallExpr>(B))
return false;

// FIXME: Here will be the entry point of the advanced activity analysis.

// Check if the expression was marked as to-be recorded by an analysis.
if (m_DiffReq.EnableTBRAnalysis)
return m_DiffReq.shouldBeRecorded(B);

// Assume E is useful to store.
return true;
}

Expand Down
8 changes: 3 additions & 5 deletions lib/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
enum Mode { kMarkingMode = 1, kNonLinearMode = 2 };
/// Tells if the variable at a given location is required to store. Basically,
/// is the result of analysis.
std::set<clang::SourceLocation> m_TBRLocs;
std::set<clang::SourceLocation>& m_TBRLocs;

/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
Expand Down Expand Up @@ -287,7 +287,8 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {

public:
/// Constructor
TBRAnalyzer(ASTContext& Context) : m_Context(Context) {
TBRAnalyzer(ASTContext& Context, std::set<clang::SourceLocation>& Locs)
: m_TBRLocs(Locs), m_Context(Context) {
m_ModeStack.push_back(0);
}

Expand All @@ -300,9 +301,6 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
TBRAnalyzer(const TBRAnalyzer&&) = delete;
TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete;

/// Returns the result of the whole analysis
std::set<clang::SourceLocation> getResult() { return m_TBRLocs; }

/// Visitors
void Analyze(const clang::FunctionDecl* FD);

Expand Down

0 comments on commit 8fc2d5d

Please sign in to comment.