Skip to content

Commit

Permalink
Add support for const cast in fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Jun 29, 2024
1 parent 4c824ed commit 1d56ef8
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BaseForwardModeVisitor
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXConstCastExpr(const clang::CXXConstCastExpr* CCE);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL);
StmtDiff VisitStringLiteral(const clang::StringLiteral* SL);
Expand Down
19 changes: 19 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,25 @@ StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr(
return StmtDiff(Clone(E), Clone(E));
}

StmtDiff
BaseForwardModeVisitor::VisitCXXConstCastExpr(const CXXConstCastExpr* CCE) {
StmtDiff subExprDiff = Visit(CCE->getSubExpr());
Expr* castExpr =
m_Sema
.BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast,
CCE->getTypeInfoAsWritten(), subExprDiff.getExpr(),
CCE->getAngleBrackets(), CCE->getSourceRange())
.get();
Expr* castExprDiff =
m_Sema
.BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast,
CCE->getTypeInfoAsWritten(),
subExprDiff.getExpr_dx(), CCE->getAngleBrackets(),
CCE->getSourceRange())
.get();
return StmtDiff(castExpr, castExprDiff);
}

StmtDiff
BaseForwardModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) {
StmtDiff subExprDiff = Visit(CSCE->getSubExpr());
Expand Down
28 changes: 28 additions & 0 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,29 @@ double fn7(double i) {
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

void* cling_runtime_internal_throwIfInvalidPointer(void *Sema, void *Expr, const void *Arg) {
return const_cast<void*>(Arg);
}

double fn8(double* params) {
double arr[] = {3.0};
return params[0]*params[0] + *(double*)(cling_runtime_internal_throwIfInvalidPointer((void*)0UL, (void*)0UL, arr));
}

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg);

// CHECK: double fn8_darg0_0(double *params) {
// CHECK-NEXT: double _d_arr[1] = {0.};
// CHECK-NEXT: double arr[1] = {3.};
// CHECK-NEXT: clad::ValueAndPushforward<void *, void *> _t0 = cling_runtime_internal_throwIfInvalidPointer_pushforward((void *)0UL, (void *)0UL, arr, (void *)0UL, (void *)0UL, _d_arr);
// CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward;
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<void *, void *> cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) {
// CHECK-NEXT: return {const_cast<void *>(Arg), const_cast<void *>(_d_Arg)};
// CHECK-NEXT: }


int main() {
INIT_DIFFERENTIATE(fn1, "i");
INIT_DIFFERENTIATE(fn2, "i");
Expand All @@ -190,4 +213,9 @@ int main() {
TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00}
TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00}
TEST_DIFFERENTIATE(fn7, 3); // CHECK-EXEC: {4.00}

double params[] = {3.0};
auto fn8_dx = clad::differentiate(fn8, "params[0]");
double d_param = fn8_dx.execute(params);
printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00}
}

0 comments on commit 1d56ef8

Please sign in to comment.