Skip to content

Commit

Permalink
Add support for parameter specification in Varied Analysis
Browse files Browse the repository at this point in the history
This PR fixes Varied Analysis to work with parameter specification and improves function calls support; it's no longer assumed that every function is varied by default.
  • Loading branch information
Max Andriychuk authored and vgvassilev committed Oct 27, 2024
1 parent 72f0edd commit 6c45be3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
5 changes: 4 additions & 1 deletion lib/Differentiator/ActivityAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,16 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) {
bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams());
if (noHiddenParam) {
MutableArrayRef<ParmVarDecl*> FDparam = FD->parameters();
m_Varied = true;
m_Marking = true;
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* par = CE->getArg(i);
TraverseStmt(par);
m_VariedDecls.insert(FDparam[i]);
}
m_Varied = false;
m_Marking = false;
}
m_Varied = true;
return true;
}

Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ namespace clad {
return true;

if (!m_ActivityRunInfo.HasAnalysisRun) {
for (const auto& dParam : DVI)
m_ActivityRunInfo.ToBeRecorded.insert(cast<VarDecl>(dParam.param));
std::copy(Function->param_begin(), Function->param_end(),
std::inserter(m_ActivityRunInfo.ToBeRecorded,
m_ActivityRunInfo.ToBeRecorded.end()));

VariedAnalyzer analyzer(Function->getASTContext(),
m_ActivityRunInfo.ToBeRecorded);
Expand Down
28 changes: 28 additions & 0 deletions test/Analyses/ActivityReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,29 @@ double f7(double x){
// CHECK-NEXT: }
// CHECK-NEXT: }

double f8_1(double v, double u){
return v;
}
double f8(double x){
double c = f8_1(1, 1);
double f = f8_1(x, 1);
return f;
}
// CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u);

// CHECK-NEXT: void f8_grad(double x, double *_d_x) {
// CHECK-NEXT: double c = f8_1(1, 1);
// CHECK-NEXT: double _d_f = 0.;
// CHECK-NEXT: double f = f8_1(x, 1);
// CHECK-NEXT: _d_f += 1;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: f8_1_pullback(x, 1, _d_f, &_r0, &_r1);
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_va>(F);\
Expand All @@ -257,6 +280,7 @@ int main(){
TEST(f5, 3);// CHECK-EXEC: {0.00}
TEST(f6, 3);// CHECK-EXEC: {0.00}
TEST(f7, 3);// CHECK-EXEC: {1.00}
TEST(f8, 3);// CHECK-EXEC: {1.00}
}

// CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) {
Expand All @@ -271,3 +295,7 @@ int main(){
// CHECK-NEXT: *_d_v += 2 * _d_n;
// CHECK-NEXT: *_d_u += 2 * _d_k;
// CHECK-NEXT: }

// CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) {
// CHECK-NEXT: *_d_v += _d_y;
// CHECK-NEXT: }

0 comments on commit 6c45be3

Please sign in to comment.