From a0eba513ac32f70bcfc84e7ff442054bc8d64e18 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 1 Jul 2024 00:14:15 +0300 Subject: [PATCH] Create a central point to generate gradient postfixes. Use postfix with pullbacks when necessary. --- .../clad/Differentiator/ReverseModeVisitor.h | 36 +++++++++++++++---- lib/Differentiator/ReverseModeVisitor.cpp | 30 +++------------- test/Gradient/FunctionCalls.C | 12 +++---- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 768a55292..7d7f7e4c7 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -73,12 +73,36 @@ namespace clad { // 'MultiplexExternalRMVSource.h' file MultiplexExternalRMVSource* m_ExternalSource = nullptr; clang::Expr* m_Pullback = nullptr; - const char* funcPostfix() const { - if (m_DiffReq.Mode == DiffMode::jacobian) - return "_jac"; - if (m_DiffReq.use_enzyme) - return "_grad_enzyme"; - return "_grad"; + + static std::string diffParamsPostfix(const DiffRequest& request) { + std::string postfix; + const DiffInputVarsInfo& DVI = request.DVI; + std::size_t numParams = request->getNumParams(); + // If Jacobian is asked, the last parameter is the result parameter + // and should be ignored + if (request.Mode == DiffMode::jacobian) + numParams -= 1; + // To be consistent with older tests, nothing is appended to 'f_grad' if + // we differentiate w.r.t. all the parameters at once. + if (DVI.size() != numParams) + for (const auto& dParam : DVI) { + const clang::ValueDecl* arg = dParam.param; + const auto* begin = request->param_begin(); + const auto* end = std::next(begin, numParams); + const auto* it = std::find(begin, end, arg); + auto idx = std::distance(begin, it); + postfix += ('_' + std::to_string(idx)); + } + return postfix; + } + + static std::string funcPostfix(const DiffRequest& request) { + std::string postfix = "_pullback"; + if (request.Mode == DiffMode::jacobian) + postfix = "_jac"; + if (request.use_enzyme) + postfix = "_grad_enzyme"; + return postfix + diffParamsPostfix(request); } /// Removes the local const qualifiers from a QualType and returns a new diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index be3b7fce0..003471442 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -118,7 +118,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, FunctionDecl* ReverseModeVisitor::CreateGradientOverload(unsigned numExtraParams) { auto gradientParams = m_Derivative->parameters(); - std::string name = m_DiffReq.BaseFunctionName + funcPostfix(); + std::string name = + m_DiffReq.BaseFunctionName + "_grad" + diffParamsPostfix(m_DiffReq); IdentifierInfo* II = &m_Context.Idents.get(name); DeclarationNameInfo DNI(II, noLoc); // Calculate the total number of parameters that would be required for @@ -303,28 +304,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } auto derivativeBaseName = request.BaseFunctionName; - std::string gradientName = derivativeBaseName + "_pullback"; - // To be consistent with older tests, nothing is appended to 'f_grad' if - // we differentiate w.r.t. all the parameters at once. - if (request.Mode == DiffMode::jacobian) { - gradientName = derivativeBaseName + "_jac"; - // If Jacobian is asked, the last parameter is the result parameter - // and should be ignored - if (args.size() != FD->getNumParams()-1){ - for (const auto* arg : args) { - const auto* const it = - std::find(FD->param_begin(), FD->param_end() - 1, arg); - auto idx = std::distance(FD->param_begin(), it); - gradientName += ('_' + std::to_string(idx)); - } - } - } else if (args.size() != FD->getNumParams()) { - for (const auto* arg : args) { - const auto* it = std::find(FD->param_begin(), FD->param_end(), arg); - auto idx = std::distance(FD->param_begin(), it); - gradientName += ('_' + std::to_string(idx)); - } - } + std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq); IdentifierInfo* II = &m_Context.Idents.get(gradientName); DeclarationNameInfo name(II, noLoc); @@ -508,8 +488,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterParsingDiffArgs(request, args); - auto derivativeName = - utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback"; + auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + + funcPostfix(m_DiffReq); auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); auto paramTypes = ComputeParamTypes(args); diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 2c8ee76e2..a0690432d 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -602,7 +602,7 @@ double add(double a, double* b) { } -//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a); +//CHECK: void add_pullback_0(double a, double *b, double _d_y, double *_d_a); //CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b); @@ -633,7 +633,7 @@ double fn17 (double x, double* y) { //CHECK-NEXT: double _r_d0 = *_d_x; //CHECK-NEXT: *_d_x = 0; //CHECK-NEXT: double _r0 = 0; -//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0); +//CHECK-NEXT: add_pullback_0(x, y, _r_d0, &_r0); //CHECK-NEXT: *_d_x += _r0; //CHECK-NEXT: } //CHECK-NEXT: } @@ -682,7 +682,7 @@ double weighted_sum(double* x, const double* w) { return w[0] * x[0] + w[1] * x[1]; } -// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x); +// CHECK: void weighted_sum_pullback_0(double *x, const double *w, double _d_y, double *_d_x); double fn20(double* x, const double* w) { const double* auxW = w + 1; @@ -691,7 +691,7 @@ double fn20(double* x, const double* w) { // CHECK: void fn20_pullback_0(double *x, const double *w, double _d_y, double *_d_x) { // CHECK-NEXT: const double *auxW = w + 1; -// CHECK-NEXT: weighted_sum_pullback(x, auxW, _d_y, _d_x); +// CHECK-NEXT: weighted_sum_pullback_0(x, auxW, _d_y, _d_x); // CHECK-NEXT: } double ptrRef(double*& ptr_ref) { @@ -1067,7 +1067,7 @@ double sq_defined_later(double x) { //CHECK-NEXT: } //CHECK-NEXT: } -//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) { +//CHECK: void add_pullback_0(double a, double *b, double _d_y, double *_d_a) { //CHECK-NEXT: *_d_a += _d_y; //CHECK-NEXT: } @@ -1089,7 +1089,7 @@ double sq_defined_later(double x) { // CHECK-NEXT: *_d_x += _d_y; // CHECK-NEXT: } -// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x) { +// CHECK: void weighted_sum_pullback_0(double *x, const double *w, double _d_y, double *_d_x) { // CHECK-NEXT: { // CHECK-NEXT: _d_x[0] += w[0] * _d_y; // CHECK-NEXT: _d_x[1] += w[1] * _d_y;