Skip to content

Commit

Permalink
Remove redundant assignments and support non-array type ranges in ran…
Browse files Browse the repository at this point in the history
…ge-based for loops
  • Loading branch information
PetroZarytskyi committed Aug 20, 2024
1 parent 0b5174e commit a41cc56
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 68 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ namespace clad {
StmtDiff VisitCaseStmt(const clang::CaseStmt* CS);
StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS);
DeclDiff<clang::VarDecl> DifferentiateVarDecl(const clang::VarDecl* VD,
bool AddToBlock = true);
bool keepLocal = false);
StmtDiff VisitSubstNonTypeTemplateParmExpr(
const clang::SubstNonTypeTemplateParmExpr* NTTP);
StmtDiff
Expand Down
42 changes: 14 additions & 28 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff
ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) {
const auto* RangeDecl = cast<VarDecl>(FRS->getRangeStmt()->getSingleDecl());
const auto* BeginDecl = cast<VarDecl>(FRS->getBeginStmt()->getSingleDecl());
DeclDiff<VarDecl> VisitRange =
DifferentiateVarDecl(RangeDecl, /*keepLocal=*/true);
DeclDiff<VarDecl> VisitBegin =
DifferentiateVarDecl(BeginDecl, /*keepLocal=*/true);

beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Expand All @@ -991,33 +998,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SaveAndRestore<bool> SaveIsInside(isInsideLoop,
/*NewValue=*/false);

const auto* RangeDecl = cast<VarDecl>(FRS->getRangeStmt()->getSingleDecl());
const auto* BeginDecl = cast<VarDecl>(FRS->getBeginStmt()->getSingleDecl());

DeclDiff<VarDecl> VisitRange = DifferentiateVarDecl(RangeDecl, false);
DeclDiff<VarDecl> VisitBegin = DifferentiateVarDecl(BeginDecl, false);

beginBlock(direction::reverse);
// Create all declarations needed.
DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl());
Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()];
DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl());
Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()];

Expr* rangeInit = Clone(FRS->getRangeInit());
Expr* d_rangeInitDeclRef =
m_Variables[cast<DeclRefExpr>(rangeInit)->getDecl()];
VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef));
Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef);
Expr* assignRange =
BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit));

addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl()));
addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx()));
addToCurrentBlock(assignAdjBegin);
addToCurrentBlock(assignRange);

const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());
QualType endType = CloneType(EndDecl->getType());
Expand Down Expand Up @@ -2718,7 +2706,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

DeclDiff<VarDecl> ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD,
bool addToBlock) {
bool keepLocal) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;

Expand All @@ -2728,7 +2716,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// declarations don't have to be moved to the function global scope.
bool promoteToFnScope =
!getCurrentScope()->isFunctionScope() &&
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass;
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass && !keepLocal;
QualType VDCloneType;
QualType VDDerivedType;
// If the cloned declaration is moved to the function global scope,
Expand Down Expand Up @@ -2896,7 +2884,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
getZeroInit(VDDerivedType));
else
assignToZero = GetCladZeroInit(declRef);
if (addToBlock)
if (!keepLocal)
addToCurrentBlock(assignToZero, direction::reverse);
}
}
Expand All @@ -2913,11 +2901,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
if (addToBlock)
addToCurrentBlock(assignDerivativeE);
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
if (addToBlock)
if (!keepLocal)
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
Expand All @@ -2944,11 +2931,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
if (addToBlock)
addToCurrentBlock(assignDerivativeE, direction::forward);
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
if (addToBlock)
if (!keepLocal)
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
Expand Down
68 changes: 29 additions & 39 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2702,14 +2702,12 @@ double fn34(double x, double y){
//CHECK-NEXT: double r = 0;
//CHECK-NEXT: double _d_a[3] = {0};
//CHECK-NEXT: double a[3] = {y, x * y, x * x + y};
//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (*__range10)[3] = &a;
//CHECK-NEXT: double (*_d___range1)[3] = &_d_a;
//CHECK-NEXT: double *__begin10 = *__range10;
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &a;
//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (&__range10)[3] = a;
//CHECK-NEXT: double (&_d___range1)[3] = _d_a;
//CHECK-NEXT: double *__begin10 = __range10;
//CHECK-NEXT: double *_d___begin1 = _d___range1;
//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}};
//CHECK-NEXT: double *_d_i = 0;
//CHECK-NEXT: double *i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
Expand Down Expand Up @@ -2778,14 +2776,12 @@ double fn35(double x, double y){
//CHECK-NEXT: double r = 0;
//CHECK-NEXT: double _d_a[3] = {0};
//CHECK-NEXT: double a[3] = {x, x * y, 0};
//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (*__range10)[3] = &a;
//CHECK-NEXT: double (*_d___range1)[3] = &_d_a;
//CHECK-NEXT: double *__begin10 = *__range10;
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &a;
//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (&__range10)[3] = a;
//CHECK-NEXT: double (&_d___range1)[3] = _d_a;
//CHECK-NEXT: double *__begin10 = __range10;
//CHECK-NEXT: double *_d___begin1 = _d___range1;
//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}};
//CHECK-NEXT: double *_d_i = 0;
//CHECK-NEXT: double *i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
Expand All @@ -2797,13 +2793,11 @@ double fn35(double x, double y){
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, {{0U|0UL|0ULL}});
//CHECK-NEXT: double (*__range20)[3] = &a;
//CHECK-NEXT: double (*_d___range2)[3] = &_d_a;
//CHECK-NEXT: double *__begin20 = *__range20;
//CHECK-NEXT: double *_d___begin2 = 0;
//CHECK-NEXT: _d___begin2 = *_d___range2;
//CHECK-NEXT: __range20 = &a;
//CHECK-NEXT: double *__end20 = *__range20 + {{3|3L}};
//CHECK-NEXT: double (&__range20)[3] = a;
//CHECK-NEXT: double (&_d___range2)[3] = _d_a;
//CHECK-NEXT: double *__begin20 = __range20;
//CHECK-NEXT: double *_d___begin2 = _d___range2;
//CHECK-NEXT: double *__end20 = __range20 + {{3|3L}};
//CHECK-NEXT: double *_d_j = 0;
//CHECK-NEXT: double *j = 0;
//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) {
Expand Down Expand Up @@ -2906,14 +2900,12 @@ double fn36(double x, double y){
//CHECK-NEXT: double a[3] = {1, 2, 3};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (*__range10)[3] = &a;
//CHECK-NEXT: double (*_d___range1)[3] = &_d_a;
//CHECK-NEXT: double *__begin10 = *__range10;
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &a;
//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (&__range10)[3] = a;
//CHECK-NEXT: double (&_d___range1)[3] = _d_a;
//CHECK-NEXT: double *__begin10 = __range10;
//CHECK-NEXT: double *_d___begin1 = _d___range1;
//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}};
//CHECK-NEXT: double _d_i = 0;
//CHECK-NEXT: double i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
Expand Down Expand Up @@ -2989,14 +2981,12 @@ double fn37(double x, double y) {
//CHECK-NEXT: double range[3] = {x, 4., y};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: unsigned {{int|long|long long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (*__range10)[3] = &range;
//CHECK-NEXT: double (*_d___range1)[3] = &_d_range;
//CHECK-NEXT: double *__begin10 = *__range10;
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &range;
//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL|0ULL}};
//CHECK-NEXT: double (&__range10)[3] = range;
//CHECK-NEXT: double (&_d___range1)[3] = _d_range;
//CHECK-NEXT: double *__begin10 = __range10;
//CHECK-NEXT: double *_d___begin1 = _d___range1;
//CHECK-NEXT: double *__end10 = __range10 + {{3|3L}};
//CHECK-NEXT: double _d_elem = 0;
//CHECK-NEXT: double elem = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
Expand Down

0 comments on commit a41cc56

Please sign in to comment.