Skip to content

Commit

Permalink
Keep adjoint types same as the original ones on function global scope
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jul 30, 2024
1 parent 10b325a commit e326768
Show file tree
Hide file tree
Showing 29 changed files with 270 additions and 273 deletions.
12 changes: 9 additions & 3 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,15 @@ namespace clad {
/// type.
static clang::QualType
getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) {
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
bool isLValueRefType = T->isLValueReferenceType();
T = T.getNonReferenceType();
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
clang::QualType nonConstType =
S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
if (isLValueRefType)
return C.getLValueReferenceType(nonConstType);
return nonConstType;
}
// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();
Expand Down
45 changes: 29 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2564,15 +2564,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool promoteToFnScope =
!getCurrentScope()->isFunctionScope() &&
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass;
QualType VDCloneType = CloneType(VD->getType());
QualType VDDerivedType = ComputeAdjointType(VDCloneType);
QualType VDCloneType;
QualType VDDerivedType;
// If the cloned declaration is moved to the function global scope,
// change its type for the corresponding adjoint type.
if (promoteToFnScope) {
VDDerivedType = ComputeAdjointType(CloneType(VD->getType()));
VDCloneType = VDDerivedType;
if (isa<ArrayType>(VDCloneType) && !isa<IncompleteArrayType>(VDCloneType))
VDCloneType =
GetCladArrayOfType(m_Context.getBaseElementType(VDCloneType));
} else {
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
Expand Down Expand Up @@ -2633,7 +2637,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
}
VDDerivedInit = getZeroInit(VDDerivedType);
if (promoteToFnScope || !isDerivativeOfRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
Expand Down Expand Up @@ -2735,7 +2742,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType) {
if (isDerivativeOfRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand Down Expand Up @@ -2766,17 +2773,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());
if (isPointerType && derivedVDE) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
BuildDeclStmt(reverseSweepDerivativePointerE));
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
BuildDeclStmt(reverseSweepDerivativePointerE));
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
} else {
VDDerived->setInit(initDiff.getExpr_dx());
}
}
if (derivedVDE)
Expand Down Expand Up @@ -2984,7 +2995,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
if (!declsDiff.empty()) {
Stmt* DSDiff = BuildDeclStmt(declsDiff);
addToBlock(DSDiff, m_Globals);
Stmts& block =
promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
addToBlock(DSDiff, block);
}

if (m_ExternalSource) {
Expand Down
26 changes: 13 additions & 13 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ float func(float* a, float* b) {
}

//CHECK: void func_grad(float *a, float *b, float *_d_a, float *_d_b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down Expand Up @@ -93,11 +93,11 @@ float func2(float* a) {
}

//CHECK: void func2_grad(float *a, float *_d_a) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down Expand Up @@ -132,12 +132,12 @@ float func3(float* a, float* b) {
}

//CHECK: void func3_grad(float *a, float *b, float *_d_a, float *_d_b) {
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: float _d_sum = 0;
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down Expand Up @@ -176,13 +176,13 @@ double func4(double x) {
}

//CHECK: void func4_grad(double x, double *_d_x) {
//CHECK-NEXT: double _d_arr[3] = {0};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_arr[3] = {0};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down Expand Up @@ -234,16 +234,15 @@ double func5(int k) {
}

//CHECK: void func5_grad(int k, int *_d_k) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t2;
//CHECK-NEXT: int _d_i0 = 0;
//CHECK-NEXT: int i0 = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: int n = k;
//CHECK-NEXT: double _d_arr[n];
//CHECK-NEXT: clad::zero_init(_d_arr, n);
Expand All @@ -258,6 +257,7 @@ double func5(int k) {
//CHECK-NEXT: clad::push(_t1, arr[i]);
//CHECK-NEXT: arr[i] = k;
//CHECK-NEXT: }
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = {{0U|0UL}};
//CHECK-NEXT: for (i0 = 0; ; i0++) {
Expand Down Expand Up @@ -310,14 +310,14 @@ double func6(double seed) {
}

//CHECK: void func6_grad(double seed, double *_d_seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: double _d_arr[3] = {0};
//CHECK-NEXT: clad::array<double> arr({{3U|3UL}});
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down Expand Up @@ -371,14 +371,14 @@ double func7(double *params) {
}

//CHECK: void func7_grad(double *params, double *_d_params) {
//CHECK-NEXT: double _d_out = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: std::size_t _d_i = 0;
//CHECK-NEXT: std::size_t i = 0;
//CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
//CHECK-NEXT: double _d_paramsPrime[1] = {0};
//CHECK-NEXT: clad::array<double> paramsPrime({{1U|1UL}});
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double _d_out = 0;
//CHECK-NEXT: double out = 0.;
//CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; ; ++i) {
Expand Down Expand Up @@ -428,10 +428,10 @@ double func8(double i, double *arr, int n) {
}

//CHECK: void func8_grad(double i, double *arr, int n, double *_d_i, double *_d_arr, int *_d_n) {
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = arr[0];
//CHECK-NEXT: arr[0] = 1;
Expand Down Expand Up @@ -478,11 +478,11 @@ double func9(double i, double j) {


//CHECK: void func9_grad(double i, double j, double *_d_i, double *_d_j) {
//CHECK-NEXT: double _d_arr[5] = {0};
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_idx = 0;
//CHECK-NEXT: int idx = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_arr[5] = {0};
//CHECK-NEXT: double arr[5] = {};
//CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (idx = 0; ; ++idx) {
Expand Down Expand Up @@ -534,12 +534,12 @@ double func10(double *arr, int n) {

//CHECK: void func10_grad_0(double *arr, int n, double *_d_arr) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; ; ++i) {
Expand Down Expand Up @@ -636,11 +636,11 @@ int main() {
}

//CHECK: void addArr_pullback(const double *arr, int n, double _d_y, double *_d_arr, int *_d_n) {
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_ret = 0;
//CHECK-NEXT: double ret = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; ; i++) {
Expand Down
2 changes: 1 addition & 1 deletion test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ double const_dot_product(double x, double y, double z) {

//CHECK: void const_dot_product_grad(double x, double y, double z, double *_d_x, double *_d_y, double *_d_z) {
//CHECK-NEXT: double _d_vars[3] = {0};
//CHECK-NEXT: double _d_consts[3] = {0};
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double _d_consts[3] = {0};
//CHECK-NEXT: double consts[3] = {1, 2, 3};
//CHECK-NEXT: {
//CHECK-NEXT: _d_vars[0] += 1 * consts[0];
Expand Down
2 changes: 1 addition & 1 deletion test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) {
// CHECK: void gauss_grad_1(double *x, double *p, double sigma, int dim, double *_d_p) __attribute__((device)) __attribute__((host)) {
//CHECK-NEXT: double _d_sigma = 0;
//CHECK-NEXT: int _d_dim = 0;
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
Expand All @@ -41,6 +40,7 @@ __device__ __host__ double gauss(double* x, double* p, double sigma, int dim) {
//CHECK-NEXT: double _t4;
//CHECK-NEXT: double _t5;
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: double t = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
Expand Down
4 changes: 2 additions & 2 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ float func4(float x, float y) {
}

//CHECK: void func4_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: double _d_z = 0;
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double _d_z = 0;
//CHECK-NEXT: double z = y;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = z + y;
Expand All @@ -112,8 +112,8 @@ float func5(float x, float y) {
}

//CHECK: void func5_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: int _d_z = 0;
//CHECK-NEXT: float _t0;
//CHECK-NEXT: int _d_z = 0;
//CHECK-NEXT: int z = 56;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = z + y;
Expand Down
14 changes: 7 additions & 7 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ float func(float x, float y) {
//CHECK: void func_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = x + y;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: y = y + y++ + y;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z = y * x;
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
Expand Down Expand Up @@ -61,9 +61,9 @@ float func2(float x, float y) {

//CHECK: void func2_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = x - y - y * y;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z = y / x;
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
Expand Down Expand Up @@ -97,15 +97,15 @@ float func3(float x, float y) {

//CHECK: void func3_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: float _d_t = 0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = x - y - y * y;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z = y;
//CHECK-NEXT: _t2 = y;
//CHECK-NEXT: _t1 = (y = x + x);
//CHECK-NEXT: float _d_t = 0;
//CHECK-NEXT: float t = x * z * _t1;
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
Expand Down Expand Up @@ -204,8 +204,8 @@ float func6(float x, float y) {
}

//CHECK: void func6_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z = helper(x, y);
//CHECK-NEXT: _ret_value0 = z * z;
//CHECK-NEXT: {
Expand Down Expand Up @@ -262,9 +262,9 @@ float func8(float x, float y) {
}

//CHECK: void func8_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float _t0;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z;
//CHECK-NEXT: _t0 = z;
//CHECK-NEXT: _t1 = x;
Expand Down Expand Up @@ -293,13 +293,13 @@ float func9(float x, float y) {

//CHECK: void func9_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float _t3;
//CHECK-NEXT: double _t4;
//CHECK-NEXT: float _t5;
//CHECK-NEXT: double _t7;
//CHECK-NEXT: float _t8;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: float z = helper(x, y) + helper2(x);
//CHECK-NEXT: _t3 = z;
//CHECK-NEXT: _t5 = x;
Expand Down
Loading

0 comments on commit e326768

Please sign in to comment.