Skip to content

Commit

Permalink
Add support for std::initializer_list in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 1, 2024
1 parent 7d1e26c commit f76201c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ namespace clad {
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand Down
36 changes: 28 additions & 8 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning, S->getBeginLoc(),
Expand Down Expand Up @@ -2705,7 +2711,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
if (VD->getInit())
if (const auto* CXXILE = dyn_cast<CXXStdInitializerListExpr>(
VD->getInit()->IgnoreImplicit()))
if (const auto* ILE =
dyn_cast<InitListExpr>(CXXILE->getSubExpr()->IgnoreImplicit()))
if (VD->getType()->isRValueReferenceType()) {
VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType());
unsigned numInits = ILE->getNumInits();
VDDerivedInit = ConstantFolder::synthesizeLiteral(
m_Context.getSizeType(), m_Context, numInits);
if (promoteToFnScope)
VDCloneType = VDDerivedType;
}
bool isRefType = VD->getType()->isLValueReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
Expand Down Expand Up @@ -2740,7 +2759,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
VDDerivedInit = getZeroInit(VD->getType());
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
Expand All @@ -2757,14 +2777,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (isDerivativeOfRefType) {
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType =
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
isRefType = false;
}
if (promoteToFnScope || !isDerivativeOfRefType)
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
Expand Down Expand Up @@ -2821,7 +2841,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
if (!isRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
Expand Down Expand Up @@ -2869,7 +2889,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType && promoteToFnScope) {
if (isRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand All @@ -2890,7 +2910,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ->
// double* ref;
// ref = &x;
if (isDerivativeOfRefType && promoteToFnScope)
if (isRefType && promoteToFnScope)
VDClone = BuildGlobalVarDecl(
VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
Expand Down

0 comments on commit f76201c

Please sign in to comment.