Skip to content

Commit

Permalink
Enable some cases of functor calls in custom pushforwards
Browse files Browse the repository at this point in the history
Previously, if a user wanted to provide a custom pushforward for a
function that uses functors in it, it was impossible to use a generated
pushforwards for that functors' call operators. This commit aims to fix
for basic functors that don't have multiple call operator overloads.

Fixes: #1023
  • Loading branch information
gojakuch committed Aug 13, 2024
1 parent 6cc83ee commit b85d631
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 14 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class BaseForwardModeVisitor

virtual void ExecuteInsidePushforwardFunctionBlock();

virtual void DifferentiateCallOperatorIfFunctor(clang::QualType QT);

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
Expand Down
60 changes: 60 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,8 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
}

StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DifferentiateCallOperatorIfFunctor(DRE->getType());

DeclRefExpr* clonedDRE = nullptr;
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
Expand Down Expand Up @@ -1594,6 +1596,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
// If the DeclStmt is not empty, check the first declaration.
if (declsBegin != declsEnd && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
DifferentiateCallOperatorIfFunctor(VD->getType());
// Check for non-differentiable types.
QualType QT = VD->getType();
if (QT->isPointerType())
Expand Down Expand Up @@ -2057,8 +2060,65 @@ StmtDiff BaseForwardModeVisitor::VisitBreakStmt(const BreakStmt* stmt) {
return StmtDiff(Clone(stmt));
}

void BaseForwardModeVisitor::DifferentiateCallOperatorIfFunctor(
clang::QualType QT) {
// Identify if the constructed type is a functor. For functors, we need to
// differentiate their call operator once an object has been constructed, to
// allow user calls to pushforwards inside user-provided custom derivatives.
// FIXME: A much more scalable solution would be to create pushforwards once
// they're called from user-provided custom derivatives. This could then be
// applied to other operators besides operator() to avoid compilation errors
// in such cases.
if (auto RD = QT->getAsCXXRecordDecl()) {
CXXRecordDecl* constructedType = RD->getDefinition();
bool isFunctor = constructedType && !constructedType->isLambda();
std::vector<const CXXMethodDecl*> callMethods;
if (isFunctor) {
for (const auto* method : constructedType->methods()) {
if (const auto* cxxMethod = dyn_cast<CXXMethodDecl>(method)) {
if (cxxMethod->isOverloadedOperator() &&
cxxMethod->getOverloadedOperator() == OO_Call) {
callMethods.push_back(cxxMethod);
}
}
}
isFunctor = isFunctor && !callMethods.empty();
}

if (isFunctor) {
for (const auto* FD : callMethods) {
CXXScopeSpec SS;
bool hasCustomDerivative =
!m_Builder
.LookupCustomDerivativeOrNumericalDiff(
clad::utils::ComputeEffectiveFnName(FD) +
GetPushForwardFunctionSuffix(),
const_cast<DeclContext*>(FD->getDeclContext()), SS)
.empty();

if (!hasCustomDerivative) {
// Request Clad to diff it.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName =
utils::ComputeEffectiveFnName(FD);
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;

// Check if request already derived in DerivedFunctions.
m_Builder.HandleNestedDiffRequest(pushforwardFnRequest);
}
}
}
}
}

StmtDiff
BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
DifferentiateCallOperatorIfFunctor(CE->getType());

// Now continue differentiating the constructor itself:
llvm::SmallVector<Expr*, 4> clonedArgs, derivedArgs;
for (auto arg : CE->arguments()) {
auto argDiff = Visit(arg);
Expand Down
47 changes: 47 additions & 0 deletions test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,36 @@ struct WidgetPointer {
}
};

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double x, F& f) {
f(x);
}

struct Foo {
double &y;
Foo(double &y): y(y) {}

double operator()(double x) {
y = 2*x;

return x;
}
};

double fn0(double x) {
Foo func = Foo{x};
use_functor(x, func);
return x;
}

#define INIT(E, ARG)\
auto d_##E = clad::differentiate(&E, ARG);\
auto d_##E##Ref = clad::differentiate(E, ARG);
Expand Down Expand Up @@ -504,4 +534,21 @@ int main() {
TEST_2(W_Arr_5, 6, 5); // CHECK-EXEC: 6.00 6.00
TEST_2(W_Pointer_3, 6, 5); // CHECK-EXEC: 37.00 37.00
TEST_2(W_Pointer_5, 6, 5); // CHECK-EXEC: 51.00 51.00

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo{{[{]*_d_x[}]*}};
// CHECK-NEXT: Foo func = Foo{{[{]*x[}]*}};
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}
31 changes: 29 additions & 2 deletions test/ForwardMode/ReferenceArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@

#include "clad/Differentiator/Differentiator.h"

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double &x, F& f) {
f(x);
}

struct Foo {
double operator()(double& x) {
x = 2*x*x;
return x;
}
};

double fn0(double x, Foo& func) {
use_functor(x, func);
return x;
}

double fn1(double& i, double& j) {
double res = i * i * j;
return res;
Expand All @@ -21,12 +46,14 @@ double fn1(double& i, double& j) {
#define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__);

#define TEST(fn, ...) \
auto res = d_##fn.execute(__VA_ARGS__); \
printf("{%.2f}\n", res)
printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__))

int main() {
INIT(fn0, "x");
INIT(fn1, "i");

double i = 3, j = 5;
TEST(fn1, i, j); // CHECK-EXEC: {30.00}
Foo fff;
TEST(fn0, i, fff); // CHECK-EXEC: {12.00}
}
24 changes: 12 additions & 12 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ Tensor<double, 5> fn5(double i, double j) {
return T;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: Tensor<double, 5> fn5_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
Expand Down Expand Up @@ -593,8 +595,6 @@ TensorD5 fn11(double i, double j) {
return res1;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx);

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > operator_plus_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &b, const Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b);
Expand Down Expand Up @@ -965,6 +965,16 @@ double fn18(double i, double j) {
// CHECK-NEXT: return _d_v[0].mem;
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

template<unsigned N>
void print(const Tensor<double, N>& t) {
for (int i=0; i<N; ++i) {
Expand Down Expand Up @@ -1071,16 +1081,6 @@ int main() {
// CHECK-NEXT: return {{[{](__imag )?}}this->[[_M_value:[a-zA-Z_]+]],{{( __imag)?}} _d_this->[[_M_value:[a-zA-Z_]+]]};
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx) {
// CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]};
// CHECK-NEXT: }
Expand Down
47 changes: 47 additions & 0 deletions test/Functors/Simple.C
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,50 @@ float f(float x) {
return x;
}

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double x, F& f) {
f(x);
}

struct Foo {
double &y;
Foo(double &y): y(y) {}

double operator()(double x) {
y = 2*x;

return x;
}
};

double fn0(double x) {
Foo func = Foo({x});
use_functor(x, func);
return x;
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo({_d_x});
// CHECK-NEXT: Foo func = Foo({x});
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}

int main() {
AFunctor doubler;
int x = doubler(5);
Expand All @@ -73,5 +117,8 @@ int main() {
auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1);
printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2

return 0;
}

0 comments on commit b85d631

Please sign in to comment.