Skip to content

Commit

Permalink
Fix storing/restoring in loops.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 18, 2023
1 parent 697624d commit 8692557
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 56 deletions.
7 changes: 5 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ namespace clad {
forceDeclCreation, IS);
}

/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;

/// For an expr E, decides if it is useful to store it in a global temporary
/// variable and replace E's further usage by a reference to that variable
/// to avoid recomputiation.
Expand Down Expand Up @@ -584,8 +589,6 @@ namespace clad {
m_BreakContStmtHandlers.pop_back();
}

std::map<clang::SourceLocation, bool> m_ToBeRecorded;

/// Registers an external RMV source.
///
/// Multiple external RMV source can be registered by calling this function
Expand Down
8 changes: 4 additions & 4 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef CLAD_TBR_ANALYZER_H
#define CLAD_TBR_ANALYZER_H
#ifndef CLAD_DIFFERENTIATOR_TBRANALYZER_H
#define CLAD_DIFFERENTIATOR_TBRANALYZER_H

#include "clang/AST/StmtVisitor.h"
#include "clad/Differentiator/CladUtils.h"
Expand All @@ -17,7 +17,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
/// type keys.
struct APIntHash {
size_t operator()(const llvm::APInt& apint) const {
return std::hash<std::string>{}(apint.toString(10, true));
return std::hash<std::string>{}(apint.toString(2, true));
}
};

Expand Down Expand Up @@ -301,4 +301,4 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
};

} // end namespace clad
#endif // CLAD_TBR_ANALYZER_H
#endif // CLAD_DIFFERENTIATOR_TBRANALYZER_H
28 changes: 2 additions & 26 deletions lib/Differentiator/ErrorEstimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,9 @@ void ErrorEstimationHandler::BuildFinalErrorStmt() {
void ErrorEstimationHandler::AddErrorStmtToBlock(Expr* var, Expr* deltaVar,
Expr* errorExpr,
bool isInsideLoop /*=false*/) {

if (auto ASE = dyn_cast<ArraySubscriptExpr>(var)) {
// If inside loop, the index has been pushed twice
// (once by ArraySubscriptExpr and the second time by us)
// pop and store it in a temporary variable to reuse later.
// FIXME: build add assign into he same expression i.e.
// _final_error += _delta_arr[pop(_t0)] += <-Error Expr->
// to avoid storage of the pop value.
Expr* popVal = ASE->getIdx();
if (isInsideLoop) {
LookupResult& Pop = m_RMV->GetCladTapePop();
CXXScopeSpec CSS;
CSS.Extend(m_RMV->m_Context, m_RMV->GetCladNamespace(), noLoc, noLoc);
auto PopDRE = m_RMV->m_Sema
.BuildDeclarationNameExpr(CSS, Pop,
/*AcceptInvalidDecl=*/false)
.get();
Expr* tapeRef = dyn_cast<CallExpr>(popVal)->getArg(0);
popVal = m_RMV->m_Sema
.ActOnCallExpr(m_RMV->getCurrentScope(), PopDRE, noLoc,
tapeRef, noLoc)
.get();
popVal = m_RMV->StoreAndRef(popVal, direction::reverse);
}
// If the variable declration refers to an array element
// create the suitable _delta_arr[i] (because we have not done
// this before).
deltaVar = getArraySubscriptExpr(deltaVar, popVal);
deltaVar = getArraySubscriptExpr(deltaVar, ASE->getIdx());
m_RMV->addToCurrentBlock(m_RMV->BuildOp(BO_AddAssign, deltaVar, errorExpr),
direction::reverse);
// immediately emit fin_err += delta_[].
Expand Down
55 changes: 35 additions & 20 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,9 +1516,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// may be changed since we have no way to determine otherwise.
// FIXME: We cannot use GlobalStoreAndRef to store a whole array so now
// arrays are not stored.
StmtDiff argDiffStore = GlobalStoreAndRef(
argDiff.getExpr(), "_t",
/*force=*/passByRef && !argDiff.getExpr()->getType()->isArrayType());
StmtDiff argDiffStore;
if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) {
argDiffStore =
GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true);
} else {
argDiffStore = {argDiff.getExpr(), argDiff.getExpr()};
}

// We need to pass the actual argument in the cloned call expression,
// instead of a temporary, for arguments passed by reference. This is
// because, callee function may modify the argument passed as reference
Expand Down Expand Up @@ -1928,24 +1933,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == UO_PostInc || opCode == UO_PostDec) {
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
auto assign =
BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
if (EStored.getExpr() != diff.getExpr()) {
auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(),
EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
}

ResultRef = diff.getExpr_dx();
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff);
} else if (opCode == UO_PreInc || opCode == UO_PreDec) {
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
auto assign =
BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);

if (EStored.getExpr() != diff.getExpr()) {
auto assign = BuildOp(BinaryOperatorKind::BO_Assign, diff.getExpr(),
EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
}
} else if (opCode == UnaryOperatorKind::UO_Real ||
opCode == UnaryOperatorKind::UO_Imag) {
diff = VisitWithExplicitNoDfDx(E);
Expand Down Expand Up @@ -2675,14 +2683,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return UsefulToStoreGlobal(UO->getSubExpr());
return true;
}
if (isa<ArraySubscriptExpr>(B)) {
auto ASE = cast<ArraySubscriptExpr>(B);
return UsefulToStoreGlobal(ASE->getBase()) || UsefulToStoreGlobal(ASE->getIdx());
}
// We lack context to decide if this is useful to store or not. In the
// current system that should have been decided by the parent expression.
// FIXME: Here will be the entry point of the advanced activity analysis.
if (isa<DeclRefExpr>(B)) {
if (isa<DeclRefExpr>(B) /* || isa<ArraySubscriptExpr>(B)*/) {
// auto line =
// m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getLine();
// auto column =
Expand Down Expand Up @@ -2998,6 +3002,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock));
m_LoopBlock.pop_back();

/// Increment statement in the for-loop is only executed if the iteration
/// did not end with a break/continue statement. Therefore, forLoopIncDiff
/// should be inside the last switch case in the reverse pass.
if (forLoopIncDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff));
} else {
bodyDiff.updateStmtDx(forLoopIncDiff);
}
}

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();
Expand All @@ -3019,7 +3035,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(counterDecrement, direction::reverse);
addToCurrentBlock(condVarDiff, direction::reverse);
addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse);
addToCurrentBlock(forLoopIncDiff, direction::reverse);
bodyDiff = {bodyDiff.getStmt(),
unwrapIfSingleStmt(endBlock(direction::reverse))};
return bodyDiff;
Expand Down
14 changes: 12 additions & 2 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,19 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) {
pair.second->merge(mergeData->val.objData[pair.first]);
}
} else if (this->type == ARR_TYPE) {
/// FIXME: Currently non-constant indices are not supported in merging.
for (auto pair : this->val.arrData) {
pair.second->merge(mergeData->val.arrData[pair.first]);
auto it = mergeData->val.arrData.find(pair.first);
if (it != mergeData->val.arrData.end()) {
pair.second->merge(it->second);
}
}
for (auto pair : mergeData->val.arrData) {
auto it = this->val.arrData.find(pair.first);
if (it == mergeData->val.arrData.end()) {
std::unordered_map<VarData*, VarData*> refVars;
this->val.arrData[pair.first] = pair.second->copy(refVars);
}
}
} else if (this->type == REF_TYPE && this->val.refData) {
this->val.refData->merge(mergeData->val.refData);
Expand Down Expand Up @@ -757,7 +768,6 @@ void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) {
/// First pass
innermostLoopBranch = reqStack.size() - 2;
firstLoopPass = true;
mergeCurBranchTo(innermostLoopBranch - 1);
if (body)
Visit(body);
if (deleteCurBranch) {
Expand Down
4 changes: 2 additions & 2 deletions test/Misc/RunDemos.C
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

// CHECK_CUSTOM_MODEL-NOT: Could not load {{.*}}cladCustomModelPlugin{{.*}}

// RUN: ./CustomModelTest.out | FileCheck -check-prefix CHECK_CUSTOM_MODEL_EXEC %s
// RUN: ./CustomModelTest.out
// CHECK_CUSTOM_MODEL_EXEC-NOT:{{.*error|warning|note:.*}}
// CHECK_CUSTOM_MODEL_EXEC: The code is:
// CHECK_CUSTOM_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_y, double &_final_error) {
Expand Down Expand Up @@ -188,7 +188,7 @@

// CHECK_PRINT_MODEL-NOT: Could not load {{.*}}cladPrintModelPlugin{{.*}}

// RUN: ./PrintModelTest.out | FileCheck -check-prefix CHECK_PRINT_MODEL_EXEC %s
// RUN: ./PrintModelTest.out
// CHECK_PRINT_MODEL_EXEC-NOT:{{.*error|warning|note:.*}}
// CHECK_PRINT_MODEL_EXEC: The code is:
// CHECK_PRINT_MODEL_EXEC-NEXT: void func_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_y, double &_final_error) {
Expand Down

0 comments on commit 8692557

Please sign in to comment.