From 5b66f0e909c0f88ce9fdd5f093255f7cc229b010 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 28 Mar 2024 17:23:20 +0200 Subject: [PATCH] Don't create adjoint pullback parameters for non-differentiable arguments. In most cases, we support non-differentiable variables (i.e. variables that don't have adjoints). Currently, their main use is for array parameters that are not among the requested independent variables. For instance, for ``` double fn17 (double x, double* y) { return x; } ``` a request ``clad::gradient(fn17, "x");`` will produce ``` void fn17_grad_0(double x, double *y, double *_d_x) { goto _label0; _label0: ; } ``` In this example, ``y`` does not have an adjoint. However, passing ``y`` as an argument to a function call inside the differentiated function produces an error. After these changes, the non-differentiability of ``y`` is propagated to the pullback, so no adjoint parameter is generated for it there. Fixes #765. --- lib/Differentiator/DiffPlanner.cpp | 4 ++ lib/Differentiator/ReverseModeVisitor.cpp | 20 ++++---- test/Gradient/FunctionCalls.C | 58 +++++++++++++++++++++++ 3 files changed, 74 insertions(+), 8 deletions(-) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 2654562e0..061d747a3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -273,6 +273,10 @@ namespace clad { } void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) { + // Diff info for pullbacks is generated automatically, + // its parameters are not provided by the user. 
+ if (Mode == DiffMode::experimental_pullback) + return; DVI.clear(); auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index aaeec0d74..9672ea190 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -471,8 +471,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assert(m_Function && "Must not be null."); DiffParams args{}; - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - + if (!request.DVI.empty()) + for (const auto& dParam : request.DVI) + args.push_back(dParam.param); + else + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); #ifndef NDEBUG bool isStaticMethod = utils::IsStaticMethod(FD); assert((!args.empty() || !isStaticMethod) && @@ -1509,9 +1512,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); - // FIXME: We should add instructions for handling non-differentiable - // arguments. Currently we are implicitly assuming function call only - // contains differentiable arguments. bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE); for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall), @@ -1729,9 +1729,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "corresponding dfdx()."); } - DerivedCallArgs.insert(DerivedCallArgs.end(), - DerivedCallOutputArgs.begin(), - DerivedCallOutputArgs.end()); + for (Expr* arg : DerivedCallOutputArgs) + if (arg) + DerivedCallArgs.push_back(arg); pullbackCallArgs = DerivedCallArgs; if (pullback) @@ -1782,6 +1782,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Silence diag outputs in nested derivation process. 
pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = enableTBR; + bool isaMethod = isa<CXXMethodDecl>(FD); + for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) + if (DerivedCallOutputArgs[i + isaMethod]) + pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 364b8be4f..e522d4227 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -758,6 +758,59 @@ double fn16(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double add(double a, double* b) { + return a + b[0]; +} + +//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_a += _d_y; +//CHECK-NEXT: } + +//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: *_d_a += _d_y; +//CHECK-NEXT: _d_b[0] += _d_y; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn17 (double x, double* y) { + x = add(x, y); + x = add(x, &x); + return x; +} + +//CHECK: void fn17_grad_0(double x, double *y, double *_d_x) { +//CHECK-NEXT: double _t0; +//CHECK-NEXT: double _t1; +//CHECK-NEXT: _t0 = x; +//CHECK-NEXT: x = add(x, y); +//CHECK-NEXT: _t1 = x; +//CHECK-NEXT: x = add(x, &x); +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_x += 1; +//CHECK-NEXT: { +//CHECK-NEXT: x = _t1; +//CHECK-NEXT: double _r_d1 = *_d_x; +//CHECK-NEXT: *_d_x -= _r_d1; +//CHECK-NEXT: double _r1 = 0; +//CHECK-NEXT: add_pullback(x, &x, _r_d1, &_r1, &*_d_x); +//CHECK-NEXT: *_d_x += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: x = _t0; +//CHECK-NEXT: double _r_d0 = *_d_x; +//CHECK-NEXT: *_d_x -= _r_d0; +//CHECK-NEXT: double _r0 = 
0; +//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0); +//CHECK-NEXT: *_d_x += _r0; +//CHECK-NEXT: } +//CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i