Skip to content

Commit

Permalink
Create a central point to generate gradient postfixes. Use postfix wi…
Browse files Browse the repository at this point in the history
…th pullbacks when necessary.
  • Loading branch information
PetroZarytskyi committed Jul 22, 2024
1 parent 9cd45b5 commit a0eba51
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 37 deletions.
36 changes: 30 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,36 @@ namespace clad {
// 'MultiplexExternalRMVSource.h' file
MultiplexExternalRMVSource* m_ExternalSource = nullptr;
clang::Expr* m_Pullback = nullptr;
const char* funcPostfix() const {
if (m_DiffReq.Mode == DiffMode::jacobian)
return "_jac";
if (m_DiffReq.use_enzyme)
return "_grad_enzyme";
return "_grad";

static std::string diffParamsPostfix(const DiffRequest& request) {
std::string postfix;
const DiffInputVarsInfo& DVI = request.DVI;
std::size_t numParams = request->getNumParams();
// If Jacobian is asked, the last parameter is the result parameter
// and should be ignored
if (request.Mode == DiffMode::jacobian)
numParams -= 1;
// To be consistent with older tests, nothing is appended to 'f_grad' if
// we differentiate w.r.t. all the parameters at once.
if (DVI.size() != numParams)
for (const auto& dParam : DVI) {
const clang::ValueDecl* arg = dParam.param;
const auto* begin = request->param_begin();
const auto* end = std::next(begin, numParams);
const auto* it = std::find(begin, end, arg);
auto idx = std::distance(begin, it);
postfix += ('_' + std::to_string(idx));
}
return postfix;
}

static std::string funcPostfix(const DiffRequest& request) {
std::string postfix = "_pullback";
if (request.Mode == DiffMode::jacobian)
postfix = "_jac";
if (request.use_enzyme)
postfix = "_grad_enzyme";
return postfix + diffParamsPostfix(request);
}

/// Removes the local const qualifiers from a QualType and returns a new
Expand Down
30 changes: 5 additions & 25 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
FunctionDecl*
ReverseModeVisitor::CreateGradientOverload(unsigned numExtraParams) {
auto gradientParams = m_Derivative->parameters();
std::string name = m_DiffReq.BaseFunctionName + funcPostfix();
std::string name =
m_DiffReq.BaseFunctionName + "_grad" + diffParamsPostfix(m_DiffReq);
IdentifierInfo* II = &m_Context.Idents.get(name);
DeclarationNameInfo DNI(II, noLoc);
// Calculate the total number of parameters that would be required for
Expand Down Expand Up @@ -303,28 +304,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

auto derivativeBaseName = request.BaseFunctionName;
std::string gradientName = derivativeBaseName + "_pullback";
// To be consistent with older tests, nothing is appended to 'f_grad' if
// we differentiate w.r.t. all the parameters at once.
if (request.Mode == DiffMode::jacobian) {
gradientName = derivativeBaseName + "_jac";
// If Jacobian is asked, the last parameter is the result parameter
// and should be ignored
if (args.size() != FD->getNumParams()-1){
for (const auto* arg : args) {
const auto* const it =
std::find(FD->param_begin(), FD->param_end() - 1, arg);
auto idx = std::distance(FD->param_begin(), it);
gradientName += ('_' + std::to_string(idx));
}
}
} else if (args.size() != FD->getNumParams()) {
for (const auto* arg : args) {
const auto* it = std::find(FD->param_begin(), FD->param_end(), arg);
auto idx = std::distance(FD->param_begin(), it);
gradientName += ('_' + std::to_string(idx));
}
}
std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq);

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
DeclarationNameInfo name(II, noLoc);
Expand Down Expand Up @@ -508,8 +488,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(request, args);

auto derivativeName =
utils::ComputeEffectiveFnName(m_DiffReq.Function) + "_pullback";
auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) +
funcPostfix(m_DiffReq);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
Expand Down
12 changes: 6 additions & 6 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ double add(double a, double* b) {
}


//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a);
//CHECK: void add_pullback_0(double a, double *b, double _d_y, double *_d_a);

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b);

Expand Down Expand Up @@ -633,7 +633,7 @@ double fn17 (double x, double* y) {
//CHECK-NEXT: double _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x = 0;
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0);
//CHECK-NEXT: add_pullback_0(x, y, _r_d0, &_r0);
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -682,7 +682,7 @@ double weighted_sum(double* x, const double* w) {
return w[0] * x[0] + w[1] * x[1];
}

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x);
// CHECK: void weighted_sum_pullback_0(double *x, const double *w, double _d_y, double *_d_x);

double fn20(double* x, const double* w) {
const double* auxW = w + 1;
Expand All @@ -691,7 +691,7 @@ double fn20(double* x, const double* w) {

// CHECK: void fn20_pullback_0(double *x, const double *w, double _d_y, double *_d_x) {
// CHECK-NEXT: const double *auxW = w + 1;
// CHECK-NEXT: weighted_sum_pullback(x, auxW, _d_y, _d_x);
// CHECK-NEXT: weighted_sum_pullback_0(x, auxW, _d_y, _d_x);
// CHECK-NEXT: }

double ptrRef(double*& ptr_ref) {
Expand Down Expand Up @@ -1067,7 +1067,7 @@ double sq_defined_later(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: }

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) {
//CHECK: void add_pullback_0(double a, double *b, double _d_y, double *_d_a) {
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: }

Expand All @@ -1089,7 +1089,7 @@ double sq_defined_later(double x) {
// CHECK-NEXT: *_d_x += _d_y;
// CHECK-NEXT: }

// CHECK: void weighted_sum_pullback(double *x, const double *w, double _d_y, double *_d_x) {
// CHECK: void weighted_sum_pullback_0(double *x, const double *w, double _d_y, double *_d_x) {
// CHECK-NEXT: {
// CHECK-NEXT: _d_x[0] += w[0] * _d_y;
// CHECK-NEXT: _d_x[1] += w[1] * _d_y;
Expand Down

0 comments on commit a0eba51

Please sign in to comment.