Skip to content

Commit

Permalink
Keep track of whether a request is immediate in DiffRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed Oct 23, 2024
1 parent 767eb13 commit fa19625
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct DiffRequest {
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
bool EnableVariedAnalysis = false;
/// A flag specifying whether this differentiation is to be used
/// in immediate contexts.
bool ImmediateMode = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down
2 changes: 2 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,8 @@ namespace clad {
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode))
request.ImmediateMode = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
Expand Down
7 changes: 5 additions & 2 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,18 @@ namespace clad {
DiffCollector collector(DGR, CladEnabledRange, m_DiffRequestGraph, S,
opts);

#if CLANG_VERSION_MAJOR > 16
for (DiffRequest& request : m_DiffRequestGraph.getNodes()) {
if (!request.Function->isImmediateFunction() &&
!request.Function->isConstexpr())
if (!request.ImmediateMode ||
(!request.Function->isConstexpr() &&
!request.Function->isImmediateFunction()))
continue;

m_DiffRequestGraph.setCurrentProcessingNode(request);
ProcessDiffRequest(request);
m_DiffRequestGraph.markCurrentNodeProcessed();
}
#endif

// We could not delay the processing of derivatives, inform act as if each
// call is final. That would still have vgvassilev/clad#248 unresolved.
Expand Down

0 comments on commit fa19625

Please sign in to comment.