diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index ac233700d..c81b0cd2b 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -122,13 +122,16 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); if (noHiddenParam) { MutableArrayRef FDparam = FD->parameters(); + m_Varied = true; + m_Marking = true; for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); TraverseStmt(par); m_VariedDecls.insert(FDparam[i]); } + m_Varied = false; + m_Marking = false; } - m_Varied = true; return true; } diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a526a6d58..b4c7018a2 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -624,8 +624,9 @@ namespace clad { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { - for (const auto& dParam : DVI) - m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); + std::copy(Function->param_begin(), Function->param_end(), + std::inserter(m_ActivityRunInfo.ToBeRecorded, + m_ActivityRunInfo.ToBeRecorded.end())); VariedAnalyzer analyzer(Function->getASTContext(), m_ActivityRunInfo.ToBeRecorded); diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index e1b5cb35c..1a28f83be 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -241,6 +241,29 @@ double f7(double x){ // CHECK-NEXT: } // CHECK-NEXT: } +double f8_1(double v, double u){ + return v; +} +double f8(double x){ + double c = f8_1(1, 1); + double f = f8_1(x, 1); + return f; +} +// CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u); + +// CHECK-NEXT: void f8_grad(double x, double *_d_x) { +// CHECK-NEXT: double c = f8_1(1, 1); +// CHECK-NEXT: double _d_f = 0.; +// CHECK-NEXT: double f = f8_1(x, 1); +// CHECK-NEXT: _d_f += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: double _r1 = 0.; +// CHECK-NEXT: f8_1_pullback(x, 1, _d_f, &_r0, &_r1); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -257,6 +280,7 @@ int main(){ TEST(f5, 3);// CHECK-EXEC: {0.00} TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} + TEST(f8, 3);// CHECK-EXEC: {1.00} } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { @@ -271,3 +295,7 @@ int main(){ // CHECK-NEXT: *_d_v += 2 * _d_n; // CHECK-NEXT: *_d_u += 2 * _d_k; // CHECK-NEXT: } + +// CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { +// CHECK-NEXT: *_d_v += _d_y; +// CHECK-NEXT: } \ No newline at end of file