From 1d56ef8997a3225d7d37cde4383de2caf54e8574 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sat, 29 Jun 2024 18:16:16 +0200 Subject: [PATCH] Add support for const cast in fwd mode --- .../Differentiator/BaseForwardModeVisitor.h | 1 + lib/Differentiator/BaseForwardModeVisitor.cpp | 19 +++++++++++++ test/ForwardMode/Pointer.C | 28 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 484d17d4b..311fc5769 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -55,6 +55,7 @@ class BaseForwardModeVisitor StmtDiff VisitCallExpr(const clang::CallExpr* CE); StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); + StmtDiff VisitCXXConstCastExpr(const clang::CXXConstCastExpr* CCE); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCharacterLiteral(const clang::CharacterLiteral* CL); StmtDiff VisitStringLiteral(const clang::StringLiteral* SL); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 067fd77d1..2e26444e4 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1600,6 +1600,25 @@ StmtDiff BaseForwardModeVisitor::VisitImplicitValueInitExpr( return StmtDiff(Clone(E), Clone(E)); } +StmtDiff +BaseForwardModeVisitor::VisitCXXConstCastExpr(const CXXConstCastExpr* CCE) { + StmtDiff subExprDiff = Visit(CCE->getSubExpr()); + Expr* castExpr = + m_Sema + .BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast, + CCE->getTypeInfoAsWritten(), subExprDiff.getExpr(), + CCE->getAngleBrackets(), CCE->getSourceRange()) + .get(); + Expr* castExprDiff = + m_Sema + .BuildCXXNamedCast(CCE->getBeginLoc(), tok::kw_const_cast, + CCE->getTypeInfoAsWritten(), + subExprDiff.getExpr_dx(), CCE->getAngleBrackets(), + CCE->getSourceRange()) + .get(); + return StmtDiff(castExpr, castExprDiff); +} + StmtDiff BaseForwardModeVisitor::VisitCStyleCastExpr(const CStyleCastExpr* CSCE) { StmtDiff subExprDiff = Visit(CSCE->getSubExpr()); diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index cfe3598a1..317cd28a5 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -174,6 +174,29 @@ double fn7(double i) { // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +void* cling_runtime_internal_throwIfInvalidPointer(void *Sema, void *Expr, const void *Arg) { + return const_cast(Arg); +} + +double fn8(double* params) { + double arr[] = {3.0}; + return params[0]*params[0] + *(double*)(cling_runtime_internal_throwIfInvalidPointer((void*)0UL, (void*)0UL, arr)); +} + +// CHECK: clad::ValueAndPushforward cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg); + +// CHECK: double fn8_darg0_0(double *params) { +// CHECK-NEXT: double _d_arr[1] = {0.}; +// CHECK-NEXT: double arr[1] = {3.}; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = cling_runtime_internal_throwIfInvalidPointer_pushforward((void *)0UL, (void *)0UL, arr, (void *)0UL, (void *)0UL, _d_arr); +// CHECK-NEXT: return 1. * params[0] + params[0] * 1. + *(double *)_t0.pushforward; +// CHECK-NEXT: } + +// CHECK: clad::ValueAndPushforward cling_runtime_internal_throwIfInvalidPointer_pushforward(void *Sema, void *Expr, const void *Arg, void *_d_Sema, void *_d_Expr, const void *_d_Arg) { +// CHECK-NEXT: return {const_cast(Arg), const_cast(_d_Arg)}; +// CHECK-NEXT: } + + int main() { INIT_DIFFERENTIATE(fn1, "i"); INIT_DIFFERENTIATE(fn2, "i"); @@ -190,4 +213,9 @@ int main() { TEST_DIFFERENTIATE(fn5, 3, 5); // CHECK-EXEC: {57.00} TEST_DIFFERENTIATE(fn6, 3); // CHECK-EXEC: {1.00} TEST_DIFFERENTIATE(fn7, 3); // CHECK-EXEC: {4.00} + + double params[] = {3.0}; + auto fn8_dx = clad::differentiate(fn8, "params[0]"); + double d_param = fn8_dx.execute(params); + printf("{%.2f}\n", d_param); // CHECK-EXEC: {6.00} }