Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable some cases of functor calls in custom pushforwards #1038

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class BaseForwardModeVisitor

virtual void ExecuteInsidePushforwardFunctionBlock();

virtual void DifferentiateCallOperatorIfFunctor(clang::QualType QT);

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
Expand Down
60 changes: 60 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,8 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
}

StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DifferentiateCallOperatorIfFunctor(DRE->getType());

DeclRefExpr* clonedDRE = nullptr;
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
Expand Down Expand Up @@ -1594,6 +1596,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
// If the DeclStmt is not empty, check the first declaration.
if (declsBegin != declsEnd && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
DifferentiateCallOperatorIfFunctor(VD->getType());
// Check for non-differentiable types.
QualType QT = VD->getType();
if (QT->isPointerType())
Expand Down Expand Up @@ -2057,8 +2060,65 @@ StmtDiff BaseForwardModeVisitor::VisitBreakStmt(const BreakStmt* stmt) {
return StmtDiff(Clone(stmt));
}

void BaseForwardModeVisitor::DifferentiateCallOperatorIfFunctor(
clang::QualType QT) {
// Identify if the constructed type is a functor. For functors, we need to
// differentiate their call operator once an object has been constructed, to
// allow user calls to pushforwards inside user-provided custom derivatives.
// FIXME: A much more scalable solution would be to create pushforwards once
// they're called from user-provided custom derivatives. This could then be
// applied to other operators besides operator() to avoid compilation errors
// in such cases.
if (auto* RD = QT->getAsCXXRecordDecl()) {
CXXRecordDecl* constructedType = RD->getDefinition();
bool isFunctor = constructedType && !constructedType->isLambda();
std::vector<const CXXMethodDecl*> callMethods;
if (isFunctor) {
for (const auto* method : constructedType->methods()) {
if (const auto* cxxMethod = dyn_cast<CXXMethodDecl>(method)) {
if (cxxMethod->isOverloadedOperator() &&
cxxMethod->getOverloadedOperator() == OO_Call) {
callMethods.push_back(cxxMethod);
}
}
}
isFunctor = isFunctor && !callMethods.empty();
}

if (isFunctor) {
for (const auto* FD : callMethods) {
CXXScopeSpec SS;
bool hasCustomDerivative =
!m_Builder
.LookupCustomDerivativeOrNumericalDiff(
clad::utils::ComputeEffectiveFnName(FD) +
GetPushForwardFunctionSuffix(),
const_cast<DeclContext*>(FD->getDeclContext()), SS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

                     const_cast<DeclContext*>(FD->getDeclContext()), SS)
                     ^

.empty();

if (!hasCustomDerivative) {
// Request Clad to diff it.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName =
utils::ComputeEffectiveFnName(FD);
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;

// Check if request already derived in DerivedFunctions.
m_Builder.HandleNestedDiffRequest(pushforwardFnRequest);
}
}
}
}
}

StmtDiff
BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
DifferentiateCallOperatorIfFunctor(CE->getType());

// Now continue differentiating the constructor itself:
llvm::SmallVector<Expr*, 4> clonedArgs, derivedArgs;
for (auto arg : CE->arguments()) {
auto argDiff = Visit(arg);
Expand Down
47 changes: 47 additions & 0 deletions test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,36 @@ struct WidgetPointer {
}
};

// Custom pushforward for use_functor: instead of letting Clad differentiate
// use_functor itself, route the call to the Clad-generated pushforward of the
// functor's call operator. x/f are the primal arguments; d_x/d_f their tangents.
namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
  // The tangent object is passed as the "_d_this" pointer expected by the
  // generated operator_call_pushforward.
  F* const d_f_addr = &d_f;
  f.operator_call_pushforward(x, d_f_addr, d_x);
}
}
}
// Generic helper under test: invokes the functor on x, ignoring the result.
// Its derivative is provided by the custom use_functor_pushforward above, so
// Clad never differentiates this body directly.
template <typename F>
void use_functor(double x, F& f) {
  // Explicit operator() spelling — identical call to f(x).
  f.operator()(x);
}

// Functor with a captured reference: invoking it writes 2*x into the bound
// variable and returns x unchanged. Clad must generate
// operator_call_pushforward for this type so the custom derivative of
// use_functor can call it.
struct Foo {
double &y; // bound output variable; in fn0 this aliases the differentiated parameter
Foo(double &y): y(y) {}

double operator()(double x) {
// Side effect through the reference member — the pushforward must
// propagate the tangent into _d_this->y as well.
y = 2*x;

return x;
}
};

// Entry point under differentiation. func.y aliases the parameter x, so the
// functor call inside use_functor rewrites x to 2*x before it is returned;
// the derivative w.r.t. x is therefore 2 (see the CHECK-EXEC below).
double fn0(double x) {
Foo func = Foo{x};
use_functor(x, func);
return x;
}

#define INIT(E, ARG)\
auto d_##E = clad::differentiate(&E, ARG);\
auto d_##E##Ref = clad::differentiate(E, ARG);
Expand Down Expand Up @@ -504,4 +534,21 @@ int main() {
TEST_2(W_Arr_5, 6, 5); // CHECK-EXEC: 6.00 6.00
TEST_2(W_Pointer_3, 6, 5); // CHECK-EXEC: 37.00 37.00
TEST_2(W_Pointer_5, 6, 5); // CHECK-EXEC: 51.00 51.00

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo{{[{]*_d_x[}]*}};
// CHECK-NEXT: Foo func = Foo{{[{]*x[}]*}};
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}
31 changes: 29 additions & 2 deletions test/ForwardMode/ReferenceArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@

#include "clad/Differentiator/Differentiator.h"

// Custom pushforward for use_functor (reference-argument variant): forwards
// to the Clad-generated pushforward of F's call operator, propagating the
// tangents d_x and d_f alongside the primal x and f.
namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) {
  // The functor's tangent object plays the role of "_d_this".
  F* const d_f_addr = &d_f;
  f.operator_call_pushforward(x, d_f_addr, d_x);
}
}
}
template <typename F>
void use_functor(double &x, F& f) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the goal is to differentiate use_functor, then can we provide a custom derivative operator_call_pushforward(Functor *F, double &x, Functor *d_F, double d_x) so that Clad can itself differentiate the use_functor<Foo> instantiation? This is more general than providing a custom derivative for use_functor<Foo> and can be helpful if there are multiple functions using the Foo functor.

f(x);
}

// Stateless functor taking its argument by reference: overwrites x with
// 2*x*x and returns the new value. With x passed by reference, the
// derivative of the update at x is 4*x (e.g. 12 at x = 3, matching the
// CHECK-EXEC in main).
struct Foo {
double operator()(double& x) {
x = 2*x*x;
return x;
}
};

// Entry point under differentiation: the functor mutates the local copy of x
// through use_functor's reference parameter, then x is returned. Exercises a
// functor passed by reference into the differentiated function.
double fn0(double x, Foo& func) {
use_functor(x, func);
return x;
}

double fn1(double& i, double& j) {
double res = i * i * j;
return res;
Expand All @@ -21,12 +46,14 @@ double fn1(double& i, double& j) {
#define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__);

#define TEST(fn, ...) \
auto res = d_##fn.execute(__VA_ARGS__); \
printf("{%.2f}\n", res)
printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__))

int main() {
// Differentiate fn0 w.r.t. "x" and fn1 w.r.t. "i".
INIT(fn0, "x");
INIT(fn1, "i");

double i = 3, j = 5;
TEST(fn1, i, j); // CHECK-EXEC: {30.00}
// The functor is forwarded by reference through the derivative's signature.
Foo fff;
TEST(fn0, i, fff); // CHECK-EXEC: {12.00}
}
24 changes: 12 additions & 12 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ Tensor<double, 5> fn5(double i, double j) {
return T;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: Tensor<double, 5> fn5_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
Expand Down Expand Up @@ -593,8 +595,6 @@ TensorD5 fn11(double i, double j) {
return res1;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx);

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > operator_plus_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &b, const Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b);
Expand Down Expand Up @@ -965,6 +965,16 @@ double fn18(double i, double j) {
// CHECK-NEXT: return _d_v[0].mem;
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

template<unsigned N>
void print(const Tensor<double, N>& t) {
for (int i=0; i<N; ++i) {
Expand Down Expand Up @@ -1071,16 +1081,6 @@ int main() {
// CHECK-NEXT: return {{[{](__imag )?}}this->[[_M_value:[a-zA-Z_]+]],{{( __imag)?}} _d_this->[[_M_value:[a-zA-Z_]+]]};
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx) {
// CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]};
// CHECK-NEXT: }
Expand Down
47 changes: 47 additions & 0 deletions test/Functors/Simple.C
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,50 @@ float f(float x) {
return x;
}

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we should not use clad-generated derivatives inside custom derivatives. This will make the custom derivatives less portable, because compilation will fail in cases where Clad has not generated the derivative.

Using clad-generated member function derivatives in custom derivatives can only work for template functions. This makes the functionality inconsistent. For example:

#include "clad/Differentiator/Differentiator.h"
#include <iostream>

struct Foo {
    double &y;
    Foo(double &y): y(y) {} 

    double operator()(double x) {
        y = 2*x;
        return x;
    }
};

namespace clad {
    namespace custom_derivatives {
        void use_functor_pushforward(double x, Foo &f, double d_x, Foo &d_f) {
            f.operator_call_pushforward(x, &d_f, d_x);
        }
    }
}

void use_functor(double x, Foo &f) {
    f(x);
}


double fn(double x) {
    Foo func = Foo(x);
    use_functor(x, func);
    return x;
}

int main() {
    auto fn_grad = clad::differentiate(fn);
}

Compilation of above fails with the error message:

FunctorCustomDerv.cpp:17:15: error: no member named 'operator_call_pushforward' in 'Foo'
            f.operator_call_pushforward(x, &d_f, d_x);

Things work in template functions because of the SFINAE rule and the delayed instantiation of templates.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I need this template situation to be solved to support Kokkos in particular, and I think such a pattern would be relatively common when supporting other similar frameworks in Clad (or when the users try to do that). as I've mentioned, this is a blocker there. I agree that this PR only addresses things that occur when using templates and I'm okay with trying to generalise this. I don't really advocate for this approach specifically, but what I need here is that:

  1. it should be possible to provide custom pushforwards for functions (like use_functor in this example) that the user can call with their data types without defining anything else themselves. so we need to be able to provide general template pushforwards. this way we could support those frameworks and not demand anything from the user, which is a requirement (the users should be able to use this in their existing applications at least theoretically and shouldn't modify the code of their classes to use them with Kokkos+Clad). the first thing I thought about was to just track the usage of undefined functions or methods with _pushforward in their name and then tell clad to put that into the differentiation plan properly, but I'm not sure how that can be implemented and if that doesn't introduce even more inconsistencies. Vassil has suggested that we can solve this in a way similar to how this PR works (just differentiate call operators in functors once they're used), so I did it like that. I should probably move these actions from the visitor to the planner though, that's my bad.

  2. whatever the approach to solving this issue would be, I need it to work for both functors AND lambdas once we get the lambda support in Clad. what I mean is that the custom pushforward template function for the use_functor routine in this example should also work for lambdas. for lambdas, we'd differentiate their call methods as soon as the lambda is created anyway (as far as I imagine that right now), and we cannot rely on custom derivatives for the call methods at all, since there's kind of no way to provide these for lambdas from the user side, and I'm not sure how that would need to work even.

so I think what these two requirements imply is that the usage of generated pushforwards in custom derivatives is unavoidable in supporting some frameworks that operate with functors and lambdas like the Kokkos does (at least from what I see). but I'd be glad to use a more general or prettier approach here. can we hop on a call to resolve this maybe?

}
}
}
// Generic helper under test: applies the functor to x and discards the
// result. Its derivative comes from the custom use_functor_pushforward, so
// Clad does not differentiate this body itself.
template <typename F>
void use_functor(double x, F& f) {
  // Explicit operator() spelling — identical call to f(x).
  f.operator()(x);
}

// Functor with a captured reference: invoking it stores 2*x into the bound
// variable and returns x unchanged. Clad is expected to generate
// operator_call_pushforward for this type (see the CHECK lines below).
struct Foo {
double &y; // bound output; in fn0 this aliases the differentiated parameter
Foo(double &y): y(y) {}

double operator()(double x) {
// Side effect through the reference member; the generated pushforward
// mirrors it on _d_this->y.
y = 2*x;

return x;
}
};

// Entry point under differentiation. func.y aliases the parameter x, so the
// functor call inside use_functor rewrites x to 2*x before it is returned,
// giving a derivative of 2 w.r.t. x (see CHECK-EXEC in main).
double fn0(double x) {
Foo func = Foo({x});
use_functor(x, func);
return x;
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo({_d_x});
// CHECK-NEXT: Foo func = Foo({x});
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}

int main() {
AFunctor doubler;
int x = doubler(5);
Expand All @@ -73,5 +117,8 @@ int main() {
auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1);
printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2

return 0;
}
Loading