From dfde3cfea7dc53fcf6394e2eedc47fd0733c1eb4 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 27 Aug 2024 13:15:17 +0300 Subject: [PATCH] Consider array parameters differentiable in forward mode --- include/clad/Differentiator/Array.h | 29 ++++++++++++ include/clad/Differentiator/VisitorBase.h | 4 ++ lib/Differentiator/BaseForwardModeVisitor.cpp | 39 ++++++++++++---- lib/Differentiator/VisitorBase.cpp | 46 +++++++++++++++++-- test/Arrays/ArrayInputsForwardMode.C | 1 + test/FirstDerivative/CallArguments.C | 18 ++++++-- test/ForwardMode/Pointer.C | 8 +++- test/ROOT/TFormula.C | 12 +++-- 8 files changed, 136 insertions(+), 21 deletions(-) diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index eef7de54e..0a4de221f 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -108,6 +108,21 @@ template class array { /// Returns the size of the underlying array CUDA_HOST_DEVICE std::size_t size() const { return m_size; } + /// Extends the size of array to `size` and default-initializer the new + /// elements if the current array size is less than `size`. + CUDA_HOST_DEVICE void extend(std::size_t size) { + if (size > m_size) { + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + T* extendedArr = new T[size]; + for (std::size_t i = 0; i < m_size; ++i) + extendedArr[i] = m_arr[i]; + for (std::size_t i = m_size; i < size; ++i) + extendedArr[i] = T(); + delete m_arr; + m_arr = extendedArr; + m_size = size; + } + } /// Iterator functions CUDA_HOST_DEVICE T* begin() { return m_arr; } CUDA_HOST_DEVICE const T* begin() const { return m_arr; } @@ -446,6 +461,20 @@ operator/(const array& arr1, const array& arr2) { arr2); } +namespace custom_derivatives { +namespace class_functions { +template +void extend_reverse_forw(array* arr, std::size_t size, array* d_arr, + std::size_t d_size) { + arr->extend(size); + d_arr->extend(size); +} +template +void extend_pullback(array* arr, std::size_t size, array* d_arr, + std::size_t* d_size) {} +} // namespace class_functions +} // namespace custom_derivatives + } // namespace clad #endif // CLAD_ARRAY_H diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 210f82112..b7291dc16 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -615,6 +615,10 @@ namespace clad { clang::SourceLocation srcLoc); clang::QualType DetermineCladArrayValueType(clang::QualType T); + /// Extend the size of `arr` to safely access the element corresponding to + /// `idx`. Works only for clad::array when handling array parameters in + /// forward mode. + void EmitCladArrayExtend(StmtDiff arr, clang::Expr* idx); /// Returns clad::Identify template declaration. clang::TemplateDecl* GetCladConstructorPushforwardTag(); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 8015b8fdb..be60c3aeb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -229,15 +229,31 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { // non-reference type for creating the derivatives. QualType dParamType = param->getType().getNonReferenceType(); // We do not create derived variable for array/pointer parameters. - if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) || - utils::isArrayOrPointerType(dParamType)) + if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType)) continue; Expr* dParam = nullptr; + bool isArrayTy = utils::isArrayOrPointerType(dParamType); if (dParamType->isRealType()) { // If param is independent variable, its derivative is 1, otherwise 0. int dValue = (param == m_IndependentVar); dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, dValue); + } else if (isArrayTy) { + if (param == m_IndependentVar) + continue; + + if (const auto* DT = dyn_cast(dParamType)) { + if (const auto* CAT = + dyn_cast(DT->getOriginalType())) { + Expr* zero = ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, /*val=*/0); + dParam = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get(); + dParamType = QualType::getFromOpaquePtr(CAT); + } + } else { + dParamType = GetCladArrayOfType(utils::GetValueType(dParamType)); + dParam = getZeroInit(dParamType); + } } // For each function arg, create a variable _d_arg to store derivatives // of potential reassignments, e.g.: @@ -249,6 +265,10 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() { BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam); addToCurrentBlock(BuildDeclStmt(dParamDecl)); dParam = BuildDeclRef(dParamDecl); + if (!isa(dParamType) && isArrayTy) { + llvm::SmallVector noParams{}; + dParam = BuildCallExprToMemFn(dParam, "ptr", noParams); + } if (dParamType->isRecordType() && param == m_IndependentVar) { llvm::SmallVector ref(diffVarInfo.fields.begin(), diffVarInfo.fields.end()); @@ -984,7 +1004,6 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { VD = DRE->getDecl(); } if (VD == m_IndependentVar) { - llvm::APSInt index; Expr* diffExpr = nullptr; Expr::EvalResult res; Expr::SideEffectsKind AllowSideEffects = @@ -1009,12 +1028,10 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { return StmtDiff(cloned, zero); Expr* target = it->second; - // FIXME: fix when adding array inputs - if (!isArrayOrPointerType(target->getType())) - return StmtDiff(cloned, zero); - // llvm::APSInt IVal; - // if (!I->EvaluateAsInt(IVal, m_Context)) - // return; + // The size of array parameters is unknown + // so we need to always extend the adjoint size before accessing the element. + if (utils::isArrayOrPointerType(clonedBase->getType())) + EmitCladArrayExtend({clonedBase, target}, clonedIndices.back()); // Create the _result[idx] expression. auto result_at_is = BuildArraySubscript(target, clonedIndices); return StmtDiff(cloned, result_at_is); @@ -1366,8 +1383,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { opKind == UnaryOperatorKind::UO_Imag) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_Deref) { - if (Expr* dx = diff.getExpr_dx()) + if (Expr* dx = diff.getExpr_dx()) { + EmitCladArrayExtend(diff, getZeroInit(m_Context.IntTy)); return StmtDiff(op, BuildOp(opKind, dx)); + } QualType literalTy = utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType()); return StmtDiff( diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index e8fce3628..5136191e0 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -853,11 +853,14 @@ namespace clad { derivedL = LDiff.getExpr_dx(); derivedR = RDiff.getExpr_dx(); if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) && - !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) + !utils::isArrayOrPointerType(RDiff.getExpr()->getType())) { derivedR = RDiff.getExpr(); - else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && - !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) + EmitCladArrayExtend(LDiff, derivedR); + } else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) && + !utils::isArrayOrPointerType(LDiff.getExpr()->getType())) { derivedL = LDiff.getExpr(); + EmitCladArrayExtend(RDiff, derivedL); + } } Stmt* VisitorBase::GetCladZeroInit(llvm::MutableArrayRef args) { @@ -896,4 +899,41 @@ namespace clad { VisitorBase::GetCladConstructorReverseForwTagOfType(clang::QualType T) { return InstantiateTemplate(GetCladConstructorReverseForwTag(), {T}); } + + void VisitorBase::EmitCladArrayExtend(StmtDiff arr, Expr* idx) { + // FIXME: For now, only forward mode supports not differentiating w.r.t. + // array parameters. + if (m_DiffReq.Mode != DiffMode::forward) + return; + if (isa(arr.getExpr()->IgnoreImplicit())) + if (auto* MCE = + dyn_cast(arr.getExpr_dx()->IgnoreImplicit())) { + Expr* cladArr = MCE->getImplicitObjectArgument()->IgnoreImplicit(); + if (isCladArrayType(cladArr->getType()) && + MCE->getDirectCallee()->getNameAsString() == "ptr") { + Expr* size = nullptr; + Expr::EvalResult index; + // If it's possible to determine the index at compile time, generate + // the `extend` argument as a literal. This will help us avoid ugly + // code like + // ``` + // _d_arr.extend(0 + 1); + // _d_arr[0] = ...; + // ``` + if (idx->EvaluateAsInt(index, m_Context, + Expr::SideEffectsKind::SE_NoSideEffects)) { + size = ConstantFolder::synthesizeLiteral( + m_Context.IntTy, m_Context, + index.Val.getInt().getExtValue() + 1); + } else { + Expr* one = ConstantFolder::synthesizeLiteral(m_Context.IntTy, + m_Context, /*val=*/1); + size = BuildOp(BO_Add, idx, one); + } + llvm::SmallVector param{size}; + Expr* extendCall = BuildCallExprToMemFn(cladArr, "extend", param); + addToCurrentBlock(extendCall); + } + } + } } // end namespace clad diff --git a/test/Arrays/ArrayInputsForwardMode.C b/test/Arrays/ArrayInputsForwardMode.C index f3c5a9060..fe79a17d5 100644 --- a/test/Arrays/ArrayInputsForwardMode.C +++ b/test/Arrays/ArrayInputsForwardMode.C @@ -57,6 +57,7 @@ double numMultIndex(double* arr, size_t n, double x) { } // CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) { +// CHECK-NEXT: clad::array _d_arr = {}; // CHECK-NEXT: size_t _d_n = 0; // CHECK-NEXT: double _d_x = 1; // CHECK-NEXT: bool _d_flag = 0; diff --git a/test/FirstDerivative/CallArguments.C b/test/FirstDerivative/CallArguments.C index b9219a220..dd1ac689d 100644 --- a/test/FirstDerivative/CallArguments.C +++ b/test/FirstDerivative/CallArguments.C @@ -140,13 +140,16 @@ float f_literal_args_func(float x, float y, float *z) { printf("hello world "); return x * f_literal_helper(0.5, 'a', z, nullptr); } +// CHECK: clad::ValueAndPushforward f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q); // CHECK: float f_literal_args_func_darg0(float x, float y, float *z) { // CHECK-NEXT: float _d_x = 1; // CHECK-NEXT: float _d_y = 0; +// CHECK-NEXT: clad::array _d_z = {}; // CHECK-NEXT: printf("hello world "); -// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr); -// CHECK-NEXT: return _d_x * _t0 + x * 0.F; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_literal_helper_pushforward(0.5, 'a', z, nullptr, 0., 0, _d_z.ptr(), nullptr); +// CHECK-NEXT: float &_t1 = _t0.value; +// CHECK-NEXT: return _d_x * _t1 + x * _t0.pushforward; // CHECK-NEXT: } inline unsigned int getBin(double low, double high, double val, unsigned int numBins) { @@ -162,8 +165,11 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) { // CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins); // CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) { +// CHECK-NEXT: clad::array _d_obs = {}; +// CHECK-NEXT: clad::array _d_xlArr = {}; // CHECK-NEXT: clad::ValueAndPushforward _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0); -// CHECK-NEXT: const float _d_t116 = 0.F; +// CHECK-NEXT: _d_xlArr.extend(_t0.value + 1); +// CHECK-NEXT: const float _d_t116 = *(_d_xlArr.ptr() + _t0.value); // CHECK-NEXT: const float t116 = *(xlArr + _t0.value); // CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F; // CHECK-NEXT: } @@ -214,6 +220,12 @@ int main () { // expected-no-diagnostics // CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; // CHECK-NEXT: } +// CHECK: clad::ValueAndPushforward f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q) { +// CHECK-NEXT: if (ch == 'a') +// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x}; +// CHECK-NEXT: return {-x * x, -_d_x * x + -x * _d_x}; +// CHECK-NEXT: } + // CHECK: inline clad::ValueAndPushforward getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) { // CHECK-NEXT: double _t0 = (high - low); // CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins); diff --git a/test/ForwardMode/Pointer.C b/test/ForwardMode/Pointer.C index 909035078..a18c5c8d0 100644 --- a/test/ForwardMode/Pointer.C +++ b/test/ForwardMode/Pointer.C @@ -197,7 +197,9 @@ double fn9(double* params, const double *constants) { } // CHECK: double fn9_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0.; +// CHECK-NEXT: clad::array _d_constants = {}; +// CHECK-NEXT: _d_constants.extend(1); +// CHECK-NEXT: double _d_c0 = *_d_constants.ptr(); // CHECK-NEXT: double c0 = *constants; // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } @@ -208,7 +210,9 @@ double fn10(double *params, const double *constants) { } // CHECK: double fn10_darg0_0(double *params, const double *constants) { -// CHECK-NEXT: double _d_c0 = 0.; +// CHECK-NEXT: clad::array _d_constants = {}; +// CHECK-NEXT: _d_constants.extend(1); +// CHECK-NEXT: double _d_c0 = *(_d_constants.ptr() + 0); // CHECK-NEXT: double c0 = *(constants + 0); // CHECK-NEXT: return 1. * c0 + params[0] * _d_c0; // CHECK-NEXT: } diff --git a/test/ROOT/TFormula.C b/test/ROOT/TFormula.C index 2c2bdad3c..f2faf5711 100644 --- a/test/ROOT/TFormula.C +++ b/test/ROOT/TFormula.C @@ -53,24 +53,30 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p); //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } //CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) { +//CHECK-NEXT: clad::array _d_x = {}; +//CHECK-NEXT: _d_x.extend(1); //CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]); //CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.); //CHECK-NEXT: clad::ValueAndPushforward _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.); -//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward; +//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward; //CHECK-NEXT: } Double_t TFormula_hess1(Double_t *x, Double_t *p) {