Skip to content

Commit

Permalink
Support multiple indices in clad::gradient calls.
Browse files Browse the repository at this point in the history
Currently, the user can provide a string as the second argument of ``clad::gradient`` to specify independent parameters as a list of comma-separated names. This commit allows users to specify indices alongside with names. e.g.
```
clad::gradient(fn, "0");
clad::gradient(fn, "1, z");
...
```
Previously, it was possible to provide a single index as an integer literal. e.g.
```
clad::gradient(fn, 0);
```
Fixes vgvassilev#46.
  • Loading branch information
PetroZarytskyi committed Aug 5, 2024
1 parent 7cec7c8 commit 0f4fae0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ Reverse-mode AD allows computing the gradient of `f` using *at most* a constant
1. `f` is a pointer to a function or a method to be differentiated
2. `ARGS` is either:
* not provided, then `f` is differentiated w.r.t. its every argument
* a string literal with comma-separated names of independent variables (e.g. `"x"` or `"y"` or `"x, y"` or `"y, x"`)
* a string literal with comma-separated names/indices of independent variables (e.g. `"x"`, `"y"`, `"x, y"`, `"y, x"`, "0, 1", "0, y", etc.)
* a SINGLE number representing the index of the independent variable
Since a vector of derivatives must be returned from a function generated by the reverse mode, its signature is slightly different. The generated function has `void` return type and same input arguments. The function has additional `n` arguments (where `n` refers to the number of arguments whose gradient was requested) of type `T*`, where `T` is the type of the corresponding original variable. Each of these variables stores the derivative of the elements as they appear in the orignal function signature. *The caller is responsible for allocating and zeroing-out the gradient storage*. Example:
```cpp
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,21 @@ namespace clad {
DiffInputVarInfo dVarInfo;

dVarInfo.source = diffSpec.str();
// Check if diffSpec represents an index of an independent variable.
if ('0' <= diffSpec[0] && diffSpec[0] <= '9') {
unsigned idx = std::stoi(dVarInfo.source);
// Fail if the specified index is invalid.
if (idx >= FD->getNumParams()) {
utils::EmitDiag(
semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(),
"Invalid argument index '%0' of '%1' argument(s)",
{std::to_string(idx), std::to_string(FD->getNumParams())});
return;
}
dVarInfo.param = FD->getParamDecl(idx);
DVI.push_back(dVarInfo);
continue;
}
llvm::StringRef pName = computeParamName(diffSpec);
auto it = std::find_if(std::begin(candidates), std::end(candidates),
[&pName](
Expand Down
8 changes: 5 additions & 3 deletions test/FirstDerivative/DiffInterface.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s
// RUN: %cladclang -ferror-limit=100 %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s

#include "clad/Differentiator/Differentiator.h"

Expand Down Expand Up @@ -131,8 +131,6 @@ int main () {

clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}}

clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}}

clad::differentiate(f_2, 3); // expected-error {{Invalid argument index '3' of '3' argument(s)}}

clad::differentiate(f_2, 9); // expected-error {{Invalid argument index '9' of '3' argument(s)}}
Expand All @@ -141,6 +139,10 @@ int main () {

clad::differentiate(f_2, f_2); // expected-error {{Failed to parse the parameters, must be a string or numeric literal}}

clad::gradient(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}}

clad::gradient(f_2, "9"); // expected-error {{Invalid argument index '9' of '3' argument(s)}}

clad::differentiate(f_3, 0); // expected-error {{Invalid argument index '0' of '0' argument(s)}}

float one = 1.0;
Expand Down
10 changes: 10 additions & 0 deletions test/Gradient/DiffInterface.C
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,22 @@ int main () {

auto f1_grad_y = clad::gradient(f_1, "y");
TEST(f1_grad_y, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_0 = clad::gradient(f_1, "1");
TEST(f1_grad_0, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_z = clad::gradient(f_1, "z");
TEST(f1_grad_z, &result[2]); // CHECK-EXEC: {0.00, 0.00, 2.00}

auto f1_grad_xy = clad::gradient(f_1, "x, y");
TEST(f1_grad_xy, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_0y = clad::gradient(f_1, "0, y");
TEST(f1_grad_0y, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_10 = clad::gradient(f_1, "1, 0");
TEST(f1_grad_10, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_yx = clad::gradient(f_1, "y, x");
TEST(f1_grad_yx, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

Expand Down

0 comments on commit 0f4fae0

Please sign in to comment.