-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable some cases of functor calls in custom pushforwards #1038
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,31 @@ | |
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
namespace clad { | ||
namespace custom_derivatives { | ||
template <typename F> | ||
void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) { | ||
f.operator_call_pushforward(x, &d_f, d_x); | ||
} | ||
} | ||
} | ||
template <typename F> | ||
void use_functor(double &x, F& f) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the goal is to differentiate |
||
f(x); | ||
} | ||
|
||
struct Foo { | ||
double operator()(double& x) { | ||
x = 2*x*x; | ||
return x; | ||
} | ||
}; | ||
|
||
double fn0(double x, Foo& func) { | ||
use_functor(x, func); | ||
return x; | ||
} | ||
|
||
double fn1(double& i, double& j) { | ||
double res = i * i * j; | ||
return res; | ||
|
@@ -21,12 +46,14 @@ double fn1(double& i, double& j) { | |
#define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__); | ||
|
||
#define TEST(fn, ...) \ | ||
auto res = d_##fn.execute(__VA_ARGS__); \ | ||
printf("{%.2f}\n", res) | ||
printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__)) | ||
|
||
int main() { | ||
INIT(fn0, "x"); | ||
INIT(fn1, "i"); | ||
|
||
double i = 3, j = 5; | ||
TEST(fn1, i, j); // CHECK-EXEC: {30.00} | ||
Foo fff; | ||
TEST(fn0, i, fff); // CHECK-EXEC: {12.00} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,50 @@ float f(float x) { | |
return x; | ||
} | ||
|
||
namespace clad { | ||
namespace custom_derivatives { | ||
template <typename F> | ||
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) { | ||
f.operator_call_pushforward(x, &d_f, d_x); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we should not use clad-generated derivatives inside custom derivatives. This will make the custom-derivatives less portable because the compilation will fail for the cases where Clad has not generated the derivative. Using clad-generated member function derivatives in custom derivatives can only work for template functions. This makes the functionality inconsistent. For example: #include "clad/Differentiator/Differentiator.h"
#include <iostream>
struct Foo {
double &y;
Foo(double &y): y(y) {}
double operator()(double x) {
y = 2*x;
return x;
}
};
namespace clad {
namespace custom_derivatives {
void use_functor_pushforward(double x, Foo &f, double d_x, Foo &d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
void use_functor(double x, Foo &f) {
f(x);
}
double fn(double x) {
Foo func = Foo(x);
use_functor(x, func);
return x;
}
int main() {
auto fn_grad = clad::differentiate(fn);
} Compilation of above fails with the error message:
Things work in template function because of SFINAE rule and delayed instantiations of templates. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well, I need this template situation to be solved to support Kokkos in particular, and I think such a pattern would be relatively common when supporting other similar frameworks in Clad (or when the users try to do that). as I've mentioned, this is a blocker there. I agree that this PR only addresses things that occur when using templates and I'm okay with trying to generalise this. I don't really advocate for this approach specifically, but what I need here is that:
so I think what these two requirements imply is that the usage of generated pushforwards in custom derivatives is unavoidable in supporting some frameworks that operate with functors and lambdas like the Kokkos does (at least from what I see). but I'd be glad to use a more general or prettier approach here. can we hop on a call to resolve this maybe? |
||
} | ||
} | ||
} | ||
template <typename F> | ||
void use_functor(double x, F& f) { | ||
f(x); | ||
} | ||
|
||
struct Foo { | ||
double &y; | ||
Foo(double &y): y(y) {} | ||
|
||
double operator()(double x) { | ||
y = 2*x; | ||
|
||
return x; | ||
} | ||
}; | ||
|
||
double fn0(double x) { | ||
Foo func = Foo({x}); | ||
use_functor(x, func); | ||
return x; | ||
} | ||
|
||
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x); | ||
// CHECK: double fn0_darg0(double x) { | ||
// CHECK-NEXT: double _d_x = 1; | ||
// CHECK-NEXT: Foo _d_func = Foo({_d_x}); | ||
// CHECK-NEXT: Foo func = Foo({x}); | ||
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func); | ||
// CHECK-NEXT: return _d_x; | ||
// CHECK-NEXT:} | ||
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) { | ||
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x; | ||
// CHECK-NEXT: this->y = 2 * x; | ||
// CHECK-NEXT: return {x, _d_x}; | ||
// CHECK-NEXT:} | ||
|
||
int main() { | ||
AFunctor doubler; | ||
int x = doubler(5); | ||
|
@@ -73,5 +117,8 @@ int main() { | |
auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1); | ||
printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9 | ||
|
||
auto dfn0 = clad::differentiate(fn0, "x"); | ||
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2 | ||
|
||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]