Skip to content

Commit

Permalink
Add support for std::initializer_list in the reverse mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 15, 2024
1 parent 015a389 commit e1e9ac1
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ namespace clad {
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand Down
57 changes: 48 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning, S->getBeginLoc(),
Expand Down Expand Up @@ -1479,7 +1485,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// global. Ref-type declarations cannot be moved to the function global
// scope because they can't be separated from their inits.
if (DRE->getDecl()->getType()->isReferenceType() &&
!VD->getType()->isReferenceType())
VD->getType()->isPointerType())
clonedDRE = BuildOp(UO_Deref, clonedDRE);
if (m_DiffReq.Mode == DiffMode::jacobian) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -2711,11 +2717,43 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();

bool isRefType = VD->getType()->isLValueReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;
std::string typeName;
if (const auto* RT =
utils::GetValueType(VD->getType())->getAs<RecordType>())
typeName = RT->getDecl()->getNameAsString();
if (typeName == "initializer_list") {
if (VD->getInit()) {
if (const auto* CXXILE = dyn_cast<CXXStdInitializerListExpr>(
VD->getInit()->IgnoreImplicit())) {
if (const auto* ILE = dyn_cast<InitListExpr>(
CXXILE->getSubExpr()->IgnoreImplicit())) {
VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType());
unsigned numInits = ILE->getNumInits();
VDDerivedInit = ConstantFolder::synthesizeLiteral(
m_Context.getSizeType(), m_Context, numInits);
VDCloneType = VDDerivedType;
}
} else if (isRefType) {
initDiff = Visit(VD->getInit());
if (promoteToFnScope) {
VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx());
VDDerivedType = VDDerivedInit->getType();
} else {
VDDerivedInit = initDiff.getExpr_dx();
VDDerivedType =
m_Context.getLValueReferenceType(VDDerivedInit->getType());
}
VDCloneType = VDDerivedType;
}
}
}

// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;
Expand Down Expand Up @@ -2746,7 +2784,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
VDDerivedInit = getZeroInit(VD->getType());
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
Expand All @@ -2763,14 +2802,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (isDerivativeOfRefType) {
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType =
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
isRefType = false;
}
if (promoteToFnScope || !isDerivativeOfRefType)
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
Expand Down Expand Up @@ -2827,7 +2866,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
if (!isRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
Expand Down Expand Up @@ -2876,7 +2915,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType && promoteToFnScope) {
if (isRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand All @@ -2898,7 +2937,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ->
// double* ref;
// ref = &x;
if (isDerivativeOfRefType && promoteToFnScope)
if (isRefType && promoteToFnScope)
VDClone = BuildGlobalVarDecl(
VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
Expand Down
44 changes: 44 additions & 0 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,49 @@ double f22(double x, double y) {
return t;
}

double f23(double x, double y) {
auto&& list = {1., x+y};
double res = 5;
if (x > y) {
auto& ref = list;
res = *(std::end(ref) - 1);
}
return res;
}

//CHECK: void f23_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: clad::array<double> *_d_ref = 0;
//CHECK-NEXT: clad::array<double> *ref = {};
//CHECK-NEXT: double _t0;
//CHECK-NEXT: clad::array<double> _d_list = 2UL;
//CHECK-NEXT: clad::array<double> list = {1., x + y};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 5;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _d_ref = &_d_list;
//CHECK-NEXT: ref = &list;
//CHECK-NEXT: _t0 = res;
//CHECK-NEXT: res = *(std::end(*ref) - 1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: {
//CHECK-NEXT: res = _t0;
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res = 0;
//CHECK-NEXT: *(std::end(*_d_ref) - 1) += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += _d_list[1];
//CHECK-NEXT: *_d_y += _d_list[1];
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -841,4 +884,5 @@ int main() {
TEST(f20, 1, 2); // CHECK-EXEC: {0.00, 3.00}
TEST(f21, 6, 4); // CHECK-EXEC: {1.00, 0.00}
TEST(f22, 6, 4); // CHECK-EXEC: {0.00, 0.00}
TEST(f23, 7, 5); // CHECK-EXEC: {1.00, 1.00}
}
115 changes: 115 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -3020,6 +3020,119 @@ double fn37(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn38(double x, double y) {
double sum = 0;
if (x > 0) {
auto&& range = {1., x, 2., y, 3.};
for (auto elem : range)
sum += elem;
}
return sum;
}

//CHECK: void fn38_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: clad::array<double> _d_range = {{5U|5UL}};
//CHECK-NEXT: clad::array<double> range = {};
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: range = {1., x, 2., y, 3.};
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: clad::array<double> &__range20 = range;
//CHECK-NEXT: clad::array<double> &_d___range2 = _d_range;
//CHECK-NEXT: {{const double *\*|const_iterator }}__begin20 = std::begin(__range20);
//CHECK-NEXT: double *_d___begin2 = std::begin(_d___range2);
//CHECK-NEXT: {{const double *\*|const_iterator }}__end20 = std::end(__range20);
//CHECK-NEXT: double _d_elem = 0;
//CHECK-NEXT: double elem = 0;
//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) {
//CHECK-NEXT: {
//CHECK-NEXT: _d_elem = *_d___begin2;
//CHECK-NEXT: elem = *__begin20;
//CHECK-NEXT: clad::push(_t2, elem);
//CHECK-NEXT: clad::push(_t3, _d_elem);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += elem;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: {
//CHECK-NEXT: _d___begin2--;
//CHECK-NEXT: elem = clad::pop(_t2);
//CHECK-NEXT: _d_elem = clad::pop(_t3);
//CHECK-NEXT: }
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_elem += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d___begin2 += _d_elem;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += _d_range[1];
//CHECK-NEXT: *_d_y += _d_range[3];
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn39(double x) {
double res = 0;
auto &&range = {1, 2, 3};
for (auto i = range.begin(); i != range.end(); i++) {
res += x * (*i);
}
return res;
}

//CHECK: void fn39_grad(double x, double *_d_x) {
//CHECK-NEXT: int *_d_i = 0;
//CHECK-NEXT: {{const int *\*|const_iterator }}i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: clad::array<int> _d_range = {{3U|3UL}};
//CHECK-NEXT: clad::array<int> range = {1, 2, 3};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: _d_i = std::begin(_d_range);
//CHECK-NEXT: for (i = std::begin(range); ; _d_i++ , i++) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i != std::end(range)))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
//CHECK-NEXT: res += x * (*i);
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: i--;
//CHECK-NEXT: _d_i--;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: *_d_x += _r_d0 * (*i);
//CHECK-NEXT: *_d_i += x * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
Expand Down Expand Up @@ -3109,6 +3222,8 @@ int main() {
TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00}
TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00}
TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00}
TEST_2(fn38, 6, 3); // CHECK-EXEC: {1.00, 1.00}
TEST(fn39, 9); // CHECK-EXEC: {6.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit e1e9ac1

Please sign in to comment.