Skip to content

Commit

Permalink
Improve finding of higher order custom derivatives
Browse files Browse the repository at this point in the history
Before this change, even if the user has provided a custom derivative of a
custom derivative function, Clad wasn't able to find it.
  • Loading branch information
vaithak authored and vgvassilev committed Jun 10, 2024
1 parent 56a1879 commit 2cae75f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 59 deletions.
29 changes: 6 additions & 23 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ namespace clad {
/// \returns type with namespace specifier added.
clang::QualType AddNamespaceSpecifier(clang::Sema& semaRef, clang::ASTContext& C, clang::QualType QT);

/// Finds declaration context associated with the DC1::DC2.
/// Finds declaration context associated with the DC1::DC2, but doesn't
/// replicate the common part of the declaration contexts.
/// For example, consider DC1 corresponds to the following declaration
/// context:
///
Expand All @@ -109,8 +110,10 @@ namespace clad {
///
/// and DC2 corresponds to the following declaration context:
/// ```
/// namespace A {
/// namespace B {}
/// namespace custom_derivatives {
/// namespace A {
/// namespace B {}
/// }
/// }
/// ```
/// then the function returns declartion context that correponds to
Expand All @@ -137,26 +140,6 @@ namespace clad {
bool shouldExist,
clang::DeclContext* DC = nullptr);

/// Returns the outermost declaration context, other than the translation
/// unit declaration, associated with DC. For example, consider a struct `S`
/// as follows:
///
/// ```
/// namespace A {
/// namespace B {
// struct S {};
/// }
/// }
/// ```
///
/// In this case, outermost declaration context associated with `S` is of
/// namespace `A`.
///
/// \param semaRef
/// \param[in] DC
clang::DeclContext* GetOutermostDC(clang::Sema& semaRef,
clang::DeclContext* DC);

/// Creates a `StringLiteral` node to represent string literal
/// "`str`".
///
Expand Down
18 changes: 3 additions & 15 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,14 @@ namespace clad {

DeclContext* FindDeclContext(clang::Sema& semaRef, clang::DeclContext* DC1,
clang::DeclContext* DC2) {
// llvm::errs()<<"DC1 name: "<<DC1->getDeclKindName()<<"\n";
// llvm::errs()<<"DC2 name: "<<DC2->getDeclKindName()<<"\n";
// cast<Decl>(DC1)->dumpColor();
llvm::SmallVector<clang::DeclContext*, 4> contexts;
assert((isa<NamespaceDecl>(DC1) || isa<TranslationUnitDecl>(DC1)) &&
"DC1 can only be extended if it is a "
"namespace or translation unit decl.");
while (DC2) {
// llvm::errs()<<"DC2 name: "<<DC2->getDeclKindName()<<"\n";
// If somewhere along the way we reach DC1, then we can break the loop.
if (DC2->Equals(DC1))
break;
if (isa<TranslationUnitDecl>(DC2))
break;
if (isa<LinkageSpecDecl>(DC2)) {
Expand Down Expand Up @@ -266,17 +265,6 @@ namespace clad {
return cast<NamespaceDecl>(ND->getPrimaryContext());
}

clang::DeclContext* GetOutermostDC(Sema& semaRef, clang::DeclContext* DC) {
ASTContext& C = semaRef.getASTContext();
assert(DC && "Invalid DC");
while (DC) {
if (DC->getParent() == C.getTranslationUnitDecl())
break;
DC = DC->getParent();
}
return DC;
}

StringLiteral* CreateStringLiteral(ASTContext& C, llvm::StringRef str) {
// Copied and adapted from clang::Sema::ActOnStringLiteral.
QualType CharTyConst = C.CharTy.withConst();
Expand Down
21 changes: 7 additions & 14 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,16 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// FIXME: Here `if` branch should be removed once we update
// numerical diff to use correct declaration context.
if (forCustomDerv) {
DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC);
// FIXME: We should ideally construct nested name specifier from the
// found custom derivative function. Current way will compute incorrect
// nested name specifier in some cases.
if (outermostDC &&
outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) {
utils::BuildNNS(m_Sema, originalFnDC, SS);
DC = originalFnDC;
} else {
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
}
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
Expand Down
35 changes: 35 additions & 0 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,38 @@ double f12(double a, double b) { return std::fma(a, b, b); }
//CHECK-NEXT: return _t0.pushforward;
//CHECK-NEXT: }

namespace clad{
namespace custom_derivatives{
clad::ValueAndPushforward<double, double> custom_f13_pushforward(double x, double d_x) {
return {exp(x), exp(x)*d_x};
}
clad::ValueAndPushforward<clad::ValueAndPushforward<double, double>, clad::ValueAndPushforward<double, double> > custom_f13_pushforward_pushforward(double x, double d_x, double _d_x, double _d_d_x) {
return {{exp(x), exp(x)*d_x}, {exp(x)*_d_x, exp(x)*_d_x + exp(x)*_d_d_x}};
}
}
}
double custom_f13(double x) {
return exp(x);
}
double f13(double x) {
return custom_f13(x);
}

//CHECK: double f13_darg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = clad::custom_derivatives::custom_f13_pushforward(x, _d_x);
//CHECK-NEXT: return _t0.pushforward;
//CHECK-NEXT: }
//CHECK-NEXT: double f13_d2arg0(double x) {
//CHECK-NEXT: double _d_x = 1;
//CHECK-NEXT: double _d__d_x = 0;
//CHECK-NEXT: double _d_x0 = 1;
//CHECK-NEXT: clad::ValueAndPushforward<clad::ValueAndPushforward<double, double>, clad::ValueAndPushforward<double, double> > _t0 = clad::custom_derivatives::custom_f13_pushforward_pushforward(x, _d_x0, _d_x, _d__d_x);
//CHECK-NEXT: clad::ValueAndPushforward<double, double> _d__t0 = _t0.pushforward;
//CHECK-NEXT: clad::ValueAndPushforward<double, double> _t00 = _t0.value;
//CHECK-NEXT: return _d__t0.pushforward;
//CHECK-NEXT: }

int main () { //expected-no-diagnostics
float f_result[2];
double d_result[2];
Expand Down Expand Up @@ -288,5 +320,8 @@ int main () { //expected-no-diagnostics
auto f12_darg1 = clad::differentiate(f12, 1);
printf("Result is = %f\n", f12_darg1.execute(2, 1)); //CHECK-EXEC: Result is = 3.000000

auto f13_ddx = clad::differentiate<2>(f13);
printf("Result is = %.2f\n", f13_ddx.execute(1)); //CHECK-EXEC: Result is = 2.72

return 0;
}
14 changes: 7 additions & 7 deletions test/NthDerivative/CustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ float test_trig(float x, float y, int a, int b) {
// CHECK-NEXT: int _d_a0 = 0;
// CHECK-NEXT: int _d__d_b = 0;
// CHECK-NEXT: int _d_b0 = 0;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t1 = pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t1 = _t1.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = clad::custom_derivatives::std::cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t2 = _t2.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t20 = _t2.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t3 = _t3.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t30 = _t3.value;
// CHECK-NEXT: double &_d__t4 = _d__t1.value;
Expand Down Expand Up @@ -89,16 +89,16 @@ float test_trig(float x, float y, int a, int b) {
// CHECK-NEXT: int _d_a0 = 0;
// CHECK-NEXT: int _d__d_b = 0;
// CHECK-NEXT: int _d_b0 = 0;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t0 = clad::custom_derivatives::std::sin_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t0 = _t0.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t00 = _t0.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t1 = pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t1 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t00.value, a, _t00.pushforward, _d_a0, _d__t0.value, _d_a, _d__t0.pushforward, _d__d_a);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t1 = _t1.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t10 = _t1.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<float, float>, ValueAndPushforward<float, float> > _t2 = clad::custom_derivatives::std::cos_pushforward_pushforward(x * y, _d_x0 * y + x * _d_y0, _d_x * y + x * _d_y, _d__d_x * y + _d_x0 * _d_y + _d_x * _d_y0 + x * _d__d_y);
// CHECK-NEXT: ValueAndPushforward<float, float> _d__t2 = _t2.pushforward;
// CHECK-NEXT: ValueAndPushforward<float, float> _t20 = _t2.value;
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: clad::ValueAndPushforward<ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))>, ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> > _t3 = clad::custom_derivatives::std::pow_pushforward_pushforward(_t20.value, b, _t20.pushforward, _d_b0, _d__t2.value, _d_b, _d__t2.pushforward, _d__d_b);
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _d__t3 = _t3.pushforward;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), int())), decltype(::std::pow(float(), int()))> _t30 = _t3.value;
// CHECK-NEXT: double &_d__t4 = _d__t1.value;
Expand Down

0 comments on commit 2cae75f

Please sign in to comment.