Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use initialization instead of assignment when possible in the reverse mode #1013

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,15 @@ namespace clad {
/// type.
static clang::QualType
getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) {
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
bool isLValueRefType = T->isLValueReferenceType();
T = T.getNonReferenceType();
clang::Qualifiers quals(T.getQualifiers());
quals.removeConst();
clang::QualType nonConstType =
S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals);
if (isLValueRefType)
return C.getLValueReferenceType(nonConstType);
return nonConstType;
}
// Function to Differentiate with Clad as Backend
void DifferentiateWithClad();
Expand Down Expand Up @@ -262,14 +268,18 @@ namespace clad {
struct DelayedStoreResult {
ReverseModeVisitor& V;
StmtDiff Result;
clang::VarDecl* Declaration;
bool isConstant;
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down
158 changes: 108 additions & 50 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

auto VisitBranch = [&](const Expr* Branch,
Expr* dfdx) -> std::pair<StmtDiff, StmtDiff> {
beginScope(Scope::DeclScope);
auto Result = DifferentiateSingleExpr(Branch, dfdx);
endScope();
StmtDiff BranchDiff = Result.first;
StmtDiff ExprDiff = Result.second;
Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt());
Expand Down Expand Up @@ -975,10 +977,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) {
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
beginBlock(direction::reverse);
LoopCounter loopCounter(*this);
beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope |
Scope::ContinueScope);
llvm::SaveAndRestore<Expr*> SaveCurrentBreakFlagExpr(
m_CurrentBreakFlagExpr);
m_CurrentBreakFlagExpr = nullptr;
Expand Down Expand Up @@ -2366,16 +2368,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// and restore it in the reverse pass
if (m_DiffReq.shouldBeRecorded(L)) {
StmtDiff pushPop = StoreAndRestore(LCloned);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

// We need to store values of derivative pointer variables in forward pass
// and restore them in reverse pass.
if (isPointerOp) {
StmtDiff pushPop = StoreAndRestore(Ldiff.getExpr_dx());
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
addToCurrentBlock(pushPop.getStmt(), direction::forward);
addToCurrentBlock(pushPop.getStmt_dx(), direction::reverse);
}

if (m_ExternalSource)
Expand Down Expand Up @@ -2564,15 +2566,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
bool promoteToFnScope =
!getCurrentScope()->isFunctionScope() &&
m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass;
QualType VDCloneType = CloneType(VD->getType());
QualType VDDerivedType = ComputeAdjointType(VDCloneType);
QualType VDCloneType;
QualType VDDerivedType;
// If the cloned declaration is moved to the function global scope,
// change its type for the corresponding adjoint type.
if (promoteToFnScope) {
VDDerivedType = ComputeAdjointType(CloneType(VD->getType()));
VDCloneType = VDDerivedType;
if (isa<ArrayType>(VDCloneType) && !isa<IncompleteArrayType>(VDCloneType))
VDCloneType =
GetCladArrayOfType(m_Context.getBaseElementType(VDCloneType));
} else {
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
Expand Down Expand Up @@ -2633,7 +2639,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
}
VDDerivedInit = getZeroInit(VDDerivedType);
if (promoteToFnScope || !isDerivativeOfRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
}

// FIXME: Remove the special cases introduced by `specialThisDiffCase`
Expand Down Expand Up @@ -2735,16 +2744,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType) {
if (isDerivativeOfRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getExpr_dx());
addToCurrentBlock(pushPop.getStmt(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getStmt_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
}
Expand All @@ -2766,17 +2775,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
initDiff.getExpr(), VD->isDirectInit(),
nullptr, VD->getInitStyle());
if (isPointerType && derivedVDE) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
BuildDeclStmt(reverseSweepDerivativePointerE));
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
if (promoteToFnScope) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
addToCurrentBlock(assignDerivativeE, direction::forward);
if (isInsideLoop) {
auto tape = MakeCladTapeFor(derivedVDE);
addToCurrentBlock(tape.Push);
auto* reverseSweepDerivativePointerE =
BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop);
m_LoopBlock.back().push_back(
BuildDeclStmt(reverseSweepDerivativePointerE));
derivedVDE = BuildDeclRef(reverseSweepDerivativePointerE);
}
} else {
m_Sema.AddInitializerToDecl(VDDerived, initDiff.getExpr_dx(), true);
VDDerived->setInitStyle(VarDecl::InitializationStyle::CInit);
}
}
if (derivedVDE)
Expand Down Expand Up @@ -2938,7 +2952,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
decl, Clone(getArraySizeExpr(AT, m_Context, *this)), true);
decl->setInitStyle(VarDecl::InitializationStyle::CallInit);
} else {
decl->setInit(getZeroInit(VD->getType()));
m_Sema.AddInitializerToDecl(decl, getZeroInit(VD->getType()),
/*DirectInit=*/true);
decl->setInitStyle(VarDecl::InitializationStyle::CInit);
}
}
}
Expand Down Expand Up @@ -2984,7 +3000,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
if (!declsDiff.empty()) {
Stmt* DSDiff = BuildDeclStmt(declsDiff);
addToBlock(DSDiff, m_Globals);
Stmts& block =
promoteToFnScope ? m_Globals : getCurrentBlock(direction::forward);
addToBlock(DSDiff, block);
}

if (m_ExternalSource) {
Expand Down Expand Up @@ -3154,9 +3172,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return CladTape.Last();
}

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix));
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
VarDecl* VD = BuildGlobalVarDecl(Type, prefix);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
addToCurrentBlock(decl, direction::forward);
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Expr* Set = BuildOp(BO_Assign, Ref, E);
addToCurrentBlock(Set, direction::forward);
}

return Ref;
}
Expand All @@ -3170,6 +3199,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
llvm::StringRef prefix) {
assert(E && "must be provided");
auto Type = getNonConstType(E->getType(), m_Context, m_Sema);

if (isInsideLoop) {
Expand All @@ -3184,17 +3214,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (const auto* AT = dyn_cast<ArrayType>(Type))
init = getArraySizeExpr(AT, m_Context, *this);

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init));
if (E) {
Expr* Store = BuildOp(BO_Assign, Ref, Clone(E));
Expr* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid) {
auto* r = Clone(E);
Restore = BuildOp(BO_Assign, r, Ref);
}
return {Store, Restore};
VarDecl* VD = BuildGlobalVarDecl(Type, prefix, init);
DeclStmt* decl = BuildDeclStmt(VD);
Expr* Ref = BuildDeclRef(VD);
Stmt* Store = nullptr;
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
if (isFnScope) {
Store = decl;
m_Sema.AddInitializerToDecl(VD, E, /*DirectInit=*/true);
VD->setInitStyle(VarDecl::InitializationStyle::CInit);
} else {
addToBlock(decl, m_Globals);
Store = BuildOp(BO_Assign, Ref, Clone(E));
}
return {};

Stmt* Restore = nullptr;
if (E->isModifiableLvalue(m_Context) == Expr::MLV_Valid)
Restore = BuildOp(BO_Assign, Clone(E), Ref);

return {Store, Restore};
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
Expand All @@ -3204,6 +3243,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
} else if (isFnScope) {
V.m_Sema.AddInitializerToDecl(Declaration, New, true);
Declaration->setInitStyle(VarDecl::InitializationStyle::CInit);
V.addToCurrentBlock(V.BuildDeclStmt(Declaration), direction::forward);
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
direction::forward);
Expand All @@ -3219,26 +3262,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this, Ediff,
/*isConstant*/ isConst,
/*isInsideLoop*/ false,
return DelayedStoreResult{*this,
Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
}
Expr* Ref = BuildDeclRef(GlobalStoreImpl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix));
bool isFnScope = getCurrentScope()->isFunctionScope() ||
m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass;
VarDecl* VD = BuildGlobalVarDecl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix);
Expr* Ref = BuildDeclRef(VD);
if (!isFnScope)
addToBlock(BuildDeclStmt(VD), m_Globals);
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref},
/*isConstant*/ false,
/*isInsideLoop*/ false, /*pNeedsUpdate=*/true};
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
}

ReverseModeVisitor::LoopCounter::LoopCounter(ReverseModeVisitor& RMV)
Expand Down Expand Up @@ -3310,7 +3369,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitDoStmt(const DoStmt* DS) {

beginBlock(direction::reverse);
LoopCounter loopCounter(*this);

Expand Down
Loading
Loading