Skip to content

Commit

Permalink
Correctly handle pointer reference parameters in the reverse mode (vg…
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored Jun 8, 2024
1 parent 7d330c3 commit 56a1879
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
11 changes: 5 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// global. Ref-type declarations cannot be moved to the function global
// scope because they can't be separated from their inits.
if (DRE->getDecl()->getType()->isReferenceType() &&
clonedDRE->getType()->isPointerType())
!VD->getType()->isReferenceType())
clonedDRE = BuildOp(UO_Deref, clonedDRE);
if (isVectorValued) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -1583,8 +1583,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

for (auto* argDerivative : CallArgDx) {
Expr* gradArgExpr = nullptr;
const Expr* arg = CE->getArg(idx);
if (utils::isArrayOrPointerType(arg->getType()) ||
QualType paramTy = FD->getParamDecl(idx)->getType();
if (utils::isArrayOrPointerType(paramTy) ||
isCladArrayType(argDerivative->getType()))
gradArgExpr = argDerivative;
else
Expand Down Expand Up @@ -3709,9 +3709,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

clang::QualType ReverseModeVisitor::ComputeParamType(clang::QualType T) {
QualType TValueType = utils::GetValueType(T);
TValueType.removeLocalConst();
return m_Context.getPointerType(TValueType);
QualType TValueType = utils::GetValueType(T);
return m_Context.getPointerType(TValueType);
}

llvm::SmallVector<clang::QualType, 8>
Expand Down
33 changes: 33 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,30 @@ double fn20(double* x, const double* w) {
// CHECK-NEXT: weighted_sum_pullback(x, auxW, 1, _d_x);
// CHECK-NEXT: }

double ptrRef(double*& ptr_ref) {
return *ptr_ref;
}

// CHECK: void ptrRef_pullback(double *&ptr_ref, double _d_y, double **_d_ptr_ref);

double fn21(double x) {
double* ptr = &x;
return ptrRef(ptr);
}

// CHECK: void fn21_grad(double x, double *_d_x) {
// CHECK-NEXT: double *_d_ptr = 0;
// CHECK-NEXT: double *_t0;
// CHECK-NEXT: _d_ptr = &*_d_x;
// CHECK-NEXT: double *ptr = &x;
// CHECK-NEXT: _t0 = ptr;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: ptr = _t0;
// CHECK-NEXT: ptrRef_pullback(_t0, 1, &_d_ptr);
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
Expand Down Expand Up @@ -792,6 +816,9 @@ int main() {
double dx1[] = {0.0, 0.0};
fn20_grad_0.execute(x1, w1, dx1);
printf("{%.2f, %.2f}\n", dx1[0], dx1[1]); // CHECK-EXEC: {2.00, 3.00}

INIT(fn21);
TEST1(fn21, 8); // CHECK-EXEC: {1.00}
}

double sq_defined_later(double x) {
Expand Down Expand Up @@ -1043,4 +1070,10 @@ double sq_defined_later(double x) {
// CHECK-NEXT: _d_x[0] += w[0] * _d_y;
// CHECK-NEXT: _d_x[1] += w[1] * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void ptrRef_pullback(double *&ptr_ref, double _d_y, double **_d_ptr_ref) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: **_d_ptr_ref += _d_y;
// CHECK-NEXT: }

0 comments on commit 56a1879

Please sign in to comment.