From 55057643dc7bf25ef86658495cc15cacd5abc85c Mon Sep 17 00:00:00 2001 From: Nirhar Date: Sun, 5 Jun 2022 11:11:14 +0530 Subject: [PATCH] Patch for the failure of functions with void return type and with return statement Prior to this patch, clad would throw an assertion error when attempting to differentiate void functions that contain a return statement. The assertion error would indicate that there was an attempt to differentiate a nullptr. This Patch checks if the there is an expression associsted with the null pointer and only if there is one, we differentiate it. Added tests for the same This patch fixes #432 --- lib/Differentiator/ForwardModeVisitor.cpp | 4 ++++ test/ForwardMode/MemberFunctions.C | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index 6dfc8fa3f..c9a6fcb9e 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -762,6 +762,10 @@ namespace clad { } StmtDiff ForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { + //If there is no return value, we must not attempt to differentiate + if (!RS->getRetValue()) + return nullptr; + StmtDiff retValDiff = Visit(RS->getRetValue()); Stmt* returnStmt = nullptr; if (m_Mode == DiffMode::forward) { diff --git a/test/ForwardMode/MemberFunctions.C b/test/ForwardMode/MemberFunctions.C index 2d1c728a6..5982da5cd 100644 --- a/test/ForwardMode/MemberFunctions.C +++ b/test/ForwardMode/MemberFunctions.C @@ -28,6 +28,27 @@ public: // CHECK-NEXT: return (_d_this->x + _d_this->y) * i + _t0 * _d_i + (_d_i * j + i * _d_j) * j + _t1 * _d_j; // CHECK-NEXT: } + void mem_fn_with_void_return() { + return; + } + + // CHECK: void mem_fn_with_void_return_pushforward(SimpleFunctions *_d_this) { + // CHECK-NEXT:} + + double mem_fn_with_void_function_call(double i, double j) { + mem_fn_with_void_return(); + return i*j; + } + + // CHECK: double mem_fn_with_void_function_call_darg0(double i, double j) { + // CHECK-NEXT: double _d_i = 1; + // CHECK-NEXT: double _d_j = 0; + // CHECK-NEXT: SimpleFunctions _d_this_obj; + // CHECK-NEXT: SimpleFunctions *_d_this = &_d_this_obj; + // CHECK-NEXT: this->mem_fn_with_void_return_pushforward(_d_this); + // CHECK-NEXT: return _d_i * j + i * _d_j; + // CHECK-NEXT:} + double mem_fn_with_var_arg_list(double i, double j, ...) { return (x+y)*i + i*j*j; } @@ -727,6 +748,8 @@ int main() { TEST(mem_fn, 3, 5) // CHECK-EXEC: 30.00 // CHECK-EXEC: 33.00 + TEST(mem_fn_with_void_function_call, 3, 5) //CHECK-EXEC: 5.00 + TEST(mem_fn_with_var_arg_list, 3, 5) // CHECK-EXEC: 30.00 // CHECK-EXEC: 33.00