Skip to content

Commit

Permalink
Don't create adjoint pullback parameters for non-differentiable argum…
Browse files Browse the repository at this point in the history
…ents.

In most cases, we support non-differentiable variables (i.e. variables that don't have adjoints). Currently, the major application for them is non-independent array parameters. For instance, for
```
double fn17 (double x, double* y) {
    return x;
}
```
a request ``clad::gradient(fn17, "x");`` will produce
```
void fn17_grad_0(double x, double *y, double *_d_x) {
    goto _label0;
  _label0:
    ;
}
```
In this example, ``y`` does not have an adjoint.

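However, passing ``y`` as an argument to another function call inside the differentiated function currently produces an error. After these changes, the non-differentiability of ``y`` is propagated to the callee's pullback, which is then generated without an adjoint parameter for ``y``.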

Fixes vgvassilev#765.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Mar 29, 2024
1 parent 2c07477 commit 5b66f0e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
4 changes: 4 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ namespace clad {
}

void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) {
// Diff info for pullbacks is generated automatically,
// its parameters are not provided by the user.
if (Mode == DiffMode::experimental_pullback)
return;
DVI.clear();
auto& C = semaRef.getASTContext();
const Expr* diffArgs = Args;
Expand Down
20 changes: 12 additions & 8 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assert(m_Function && "Must not be null.");

DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
#ifndef NDEBUG
bool isStaticMethod = utils::IsStaticMethod(FD);
assert((!args.empty() || !isStaticMethod) &&
Expand Down Expand Up @@ -1509,9 +1512,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

// FIXME: We should add instructions for handling non-differentiable
// arguments. Currently we are implicitly assuming function call only
// contains differentiable arguments.
bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
Expand Down Expand Up @@ -1729,9 +1729,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
"corresponding dfdx().");
}

DerivedCallArgs.insert(DerivedCallArgs.end(),
DerivedCallOutputArgs.begin(),
DerivedCallOutputArgs.end());
for (Expr* arg : DerivedCallOutputArgs)
if (arg)
DerivedCallArgs.push_back(arg);
pullbackCallArgs = DerivedCallArgs;

if (pullback)
Expand Down Expand Up @@ -1782,6 +1782,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = enableTBR;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));
FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
Expand Down
58 changes: 58 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,59 @@ double fn16(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

// Test helper: returns the sum of `a` and the first element that `b` points to.
// Only b[0] is read; nothing is written through `b`.
double add(double a, double* b) {
return a + b[0];
}

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: }

//CHECK: void add_pullback(double a, double *b, double _d_y, double *_d_a, double *_d_b) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: *_d_a += _d_y;
//CHECK-NEXT: _d_b[0] += _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }

// Test function that calls add() twice and returns the updated `x`.
// The gradient is requested with respect to "x" only in main(), so `y` has no
// adjoint there.
double fn17 (double x, double* y) {
// First call: the pointer argument is `y`.
x = add(x, y);
// Second call: the pointer argument aliases `x` itself.
x = add(x, &x);
return x;
}

//CHECK: void fn17_grad_0(double x, double *y, double *_d_x) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = add(x, y);
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: x = add(x, &x);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t1;
//CHECK-NEXT: double _r_d1 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d1;
//CHECK-NEXT: double _r1 = 0;
//CHECK-NEXT: add_pullback(x, &x, _r_d1, &_r1, &*_d_x);
//CHECK-NEXT: *_d_x += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: double _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: double _r0 = 0;
//CHECK-NEXT: add_pullback(x, y, _r_d0, &_r0);
//CHECK-NEXT: *_d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -844,4 +897,9 @@ int main() {
TEST2(fn15, 6, -2) // CHECK-EXEC: {1.00, 1.00}
INIT(fn16);
TEST2(fn16, 12, 8) // CHECK-EXEC: {8.00, 8.00}

auto fn17_grad_0 = clad::gradient(fn17, "x");
double y[] = {3.0, 2.0}, dx = 0;
fn17_grad_0.execute(5, y, &dx);
printf("{%.2f}\n", dx); // CHECK-EXEC: {2.00}
}

0 comments on commit 5b66f0e

Please sign in to comment.