From ac24906f915addd882c47b52635e334dc7e68ba7 Mon Sep 17 00:00:00 2001 From: Rohan Julka Date: Wed, 19 Jun 2024 22:58:04 +0100 Subject: [PATCH] Add validation to prevent error on empty if block --- .../clad/Differentiator/ReverseModeVisitor.h | 4 +- test/Gradient/Gradients.C | 94 +++++++++++++++++-- 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 7475a383b..11673bb63 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -405,7 +405,9 @@ namespace clad { const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff VisitCXXNullPtrLiteralExpr(const clang::CXXNullPtrLiteralExpr* NPE); - StmtDiff VisitNullStmt(const clang::NullStmt* NS) { return StmtDiff{}; }; + StmtDiff VisitNullStmt(const clang::NullStmt* NS) { + return StmtDiff{Clone(NS), Clone(NS)}; + }; static DeclDiff DifferentiateStaticAssertDecl(const clang::StaticAssertDecl* SAD); diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index 1362b268a..6c09c4038 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -832,16 +832,20 @@ double fn_cond_init(double x) { //CHECK-NEXT: } double fn_null_stmts(double x) { - ;;;;;;;;;;;;;;;;; - ;;;;;return x;;;; - ;;;;;;;;;;;;;;;;; + return x; + ; + ; } // = x -//CHECK: void fn_null_stmts_grad(double x, double *_d_x) { -//CHECK-NEXT: goto _label0; -//CHECK-NEXT: _label0: -//CHECK-NEXT: *_d_x += 1; -//CHECK-NEXT: } +//CHECK: void fn_null_stmts_grad(double x, double *_d_x) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: ; +//CHECK-NEXT: ; +//CHECK-NEXT: ; +//CHECK-NEXT: ; +//CHECK-NEXT: _label0: +//CHECK-NEXT: *_d_x += 1; +//CHECK-NEXT:} double fn_const_cond_op(double x) { return x + (x > 0 ? 1.0 : 0.0); @@ -853,6 +857,73 @@ double fn_const_cond_op(double x) { //CHECK-NEXT: *_d_x += 1; //CHECK-NEXT: } + +double fn_empty_if_block(double x) { + double res = 0; + if (x > 0) + ; + return res; +} + +//CHECK:void fn_empty_if_block_grad(double x, double *_d_x) { +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = x > 0; +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: ; +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: ; +//CHECK-NEXT:} + +double fn_empty_if_else(double x) { + double res = 0; + if ((res = 0)) + ; else { + res = 5 * x; + } + return res; +} + +//CHECK: void fn_empty_if_else_grad(double x, double *_d_x) { +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double _t0; +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: double _t1; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: { +//CHECK-NEXT: _t0 = res; +//CHECK-NEXT: _cond0 = (res = 0); +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: ; +//CHECK-NEXT: else { +//CHECK-NEXT: _t1 = res; +//CHECK-NEXT: res = 5 * x; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: { +//CHECK-NEXT: if (_cond0) +//CHECK-NEXT: ; +//CHECK-NEXT: else { +//CHECK-NEXT: { +//CHECK-NEXT: res = _t1; +//CHECK-NEXT: double _r_d1 = _d_res; +//CHECK-NEXT: _d_res -= _r_d1; +//CHECK-NEXT: *_d_x += 5 * _r_d1; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = _t0; +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: _d_res -= _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT:} + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -928,4 +999,11 @@ int main() { INIT_GRADIENT(fn_const_cond_op); TEST_GRADIENT(fn_const_cond_op, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 1.00 + + INIT_GRADIENT(fn_empty_if_block); + TEST_GRADIENT(fn_empty_if_block, /*numOfDerivativeArgs=*/1, 0, &dx); // CHECK-EXEC: 0.00 + + INIT_GRADIENT(fn_empty_if_else); + TEST_GRADIENT(fn_empty_if_else, /*numOfDerivativeArgs=*/1, 1, &dx); // CHECK-EXEC: 5.00 + }