diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0814072d9..c1bbb5a2d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1269,7 +1269,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. if (DRE->getDecl()->getType()->isReferenceType() && - clonedDRE->getType()->isPointerType()) + !VD->getType()->isReferenceType()) clonedDRE = BuildOp(UO_Deref, clonedDRE); if (isVectorValued) { if (m_VectorOutput.size() <= outputArrayCursor) @@ -1583,8 +1583,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (auto* argDerivative : CallArgDx) { Expr* gradArgExpr = nullptr; - const Expr* arg = CE->getArg(idx); - if (utils::isArrayOrPointerType(arg->getType()) || + QualType paramTy = FD->getParamDecl(idx)->getType(); + if (utils::isArrayOrPointerType(paramTy) || isCladArrayType(argDerivative->getType())) gradArgExpr = argDerivative; else @@ -3709,9 +3709,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } clang::QualType ReverseModeVisitor::ComputeParamType(clang::QualType T) { - QualType TValueType = utils::GetValueType(T); - TValueType.removeLocalConst(); - return m_Context.getPointerType(TValueType); + QualType TValueType = utils::GetValueType(T); + return m_Context.getPointerType(TValueType); } llvm::SmallVector diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index b76c46dc3..0a91b4927 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -688,6 +688,30 @@ double fn20(double* x, const double* w) { // CHECK-NEXT: weighted_sum_pullback(x, auxW, 1, _d_x); // CHECK-NEXT: } +double ptrRef(double*& ptr_ref) { + return *ptr_ref; +} + +// CHECK: void ptrRef_pullback(double *&ptr_ref, double _d_y, double **_d_ptr_ref); + +double fn21(double x) { + double* ptr = &x; + return ptrRef(ptr); +} + +// CHECK: void fn21_grad(double x, double *_d_x) { +// CHECK-NEXT: double *_d_ptr = 0; +// CHECK-NEXT: double *_t0; +// CHECK-NEXT: _d_ptr = &*_d_x; +// CHECK-NEXT: double *ptr = &x; +// CHECK-NEXT: _t0 = ptr; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: ptr = _t0; +// CHECK-NEXT: ptrRef_pullback(_t0, 1, &_d_ptr); +// CHECK-NEXT: } +// CHECK-NEXT: } template void reset(T* arr, int n) { @@ -792,6 +816,9 @@ int main() { double dx1[] = {0.0, 0.0}; fn20_grad_0.execute(x1, w1, dx1); printf("{%.2f, %.2f}\n", dx1[0], dx1[1]); // CHECK-EXEC: {2.00, 3.00} + + INIT(fn21); + TEST1(fn21, 8); // CHECK-EXEC: {1.00} } double sq_defined_later(double x) { @@ -1043,4 +1070,10 @@ double sq_defined_later(double x) { // CHECK-NEXT: _d_x[0] += w[0] * _d_y; // CHECK-NEXT: _d_x[1] += w[1] * _d_y; // CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void ptrRef_pullback(double *&ptr_ref, double _d_y, double **_d_ptr_ref) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: **_d_ptr_ref += _d_y; // CHECK-NEXT: } \ No newline at end of file