Skip to content

Commit

Permalink
Consider array parameters differentiable in forward mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Oct 30, 2024
1 parent cddc21d commit dfde3cf
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 21 deletions.
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ template <typename T> class array {

/// Returns the size of the underlying array
CUDA_HOST_DEVICE std::size_t size() const { return m_size; }
/// Extends the size of array to `size` and default-initializer the new
/// elements if the current array size is less than `size`.
CUDA_HOST_DEVICE void extend(std::size_t size) {
if (size > m_size) {
// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
T* extendedArr = new T[size];
for (std::size_t i = 0; i < m_size; ++i)
extendedArr[i] = m_arr[i];
for (std::size_t i = m_size; i < size; ++i)
extendedArr[i] = T();
delete m_arr;
m_arr = extendedArr;
m_size = size;
}
}
/// Iterator functions
CUDA_HOST_DEVICE T* begin() { return m_arr; }
CUDA_HOST_DEVICE const T* begin() const { return m_arr; }
Expand Down Expand Up @@ -446,6 +461,20 @@ operator/(const array<T>& arr1, const array<U>& arr2) {
arr2);
}

namespace custom_derivatives {
namespace class_functions {
template <typename T>
void extend_reverse_forw(array<T>* arr, std::size_t size, array<T>* d_arr,
std::size_t d_size) {
arr->extend(size);
d_arr->extend(size);
}
template <typename T>
void extend_pullback(array<T>* arr, std::size_t size, array<T>* d_arr,
std::size_t* d_size) {}
} // namespace class_functions
} // namespace custom_derivatives

} // namespace clad

#endif // CLAD_ARRAY_H
4 changes: 4 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@ namespace clad {
clang::SourceLocation srcLoc);

clang::QualType DetermineCladArrayValueType(clang::QualType T);
/// Extend the size of `arr` to safely access the element corresponding to
/// `idx`. Works only for clad::array when handling array parameters in
/// forward mode.
void EmitCladArrayExtend(StmtDiff arr, clang::Expr* idx);

/// Returns clad::Identify template declaration.
clang::TemplateDecl* GetCladConstructorPushforwardTag();
Expand Down
39 changes: 29 additions & 10 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,31 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
// non-reference type for creating the derivatives.
QualType dParamType = param->getType().getNonReferenceType();
// We do not create derived variable for array/pointer parameters.
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType) ||
utils::isArrayOrPointerType(dParamType))
if (!BaseForwardModeVisitor::IsDifferentiableType(dParamType))
continue;
Expr* dParam = nullptr;
bool isArrayTy = utils::isArrayOrPointerType(dParamType);
if (dParamType->isRealType()) {
// If param is independent variable, its derivative is 1, otherwise 0.
int dValue = (param == m_IndependentVar);
dParam = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context,
dValue);
} else if (isArrayTy) {
if (param == m_IndependentVar)
continue;

if (const auto* DT = dyn_cast<DecayedType>(dParamType)) {
if (const auto* CAT =
dyn_cast<ConstantArrayType>(DT->getOriginalType())) {
Expr* zero = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context, /*val=*/0);
dParam = m_Sema.ActOnInitList(noLoc, {zero}, noLoc).get();
dParamType = QualType::getFromOpaquePtr(CAT);
}
} else {
dParamType = GetCladArrayOfType(utils::GetValueType(dParamType));
dParam = getZeroInit(dParamType);
}
}
// For each function arg, create a variable _d_arg to store derivatives
// of potential reassignments, e.g.:
Expand All @@ -249,6 +265,10 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
BuildVarDecl(dParamType, "_d_" + param->getNameAsString(), dParam);
addToCurrentBlock(BuildDeclStmt(dParamDecl));
dParam = BuildDeclRef(dParamDecl);
if (!isa<ConstantArrayType>(dParamType) && isArrayTy) {
llvm::SmallVector<Expr*, 0> noParams{};
dParam = BuildCallExprToMemFn(dParam, "ptr", noParams);
}
if (dParamType->isRecordType() && param == m_IndependentVar) {
llvm::SmallVector<llvm::StringRef, 4> ref(diffVarInfo.fields.begin(),
diffVarInfo.fields.end());
Expand Down Expand Up @@ -984,7 +1004,6 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
VD = DRE->getDecl();
}
if (VD == m_IndependentVar) {
llvm::APSInt index;
Expr* diffExpr = nullptr;
Expr::EvalResult res;
Expr::SideEffectsKind AllowSideEffects =
Expand All @@ -1009,12 +1028,10 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
return StmtDiff(cloned, zero);

Expr* target = it->second;
// FIXME: fix when adding array inputs
if (!isArrayOrPointerType(target->getType()))
return StmtDiff(cloned, zero);
// llvm::APSInt IVal;
// if (!I->EvaluateAsInt(IVal, m_Context))
// return;
// The size of array parameters is unknown
// so we need to always extend the adjoint size before accessing the element.
if (utils::isArrayOrPointerType(clonedBase->getType()))
EmitCladArrayExtend({clonedBase, target}, clonedIndices.back());
// Create the _result[idx] expression.
auto result_at_is = BuildArraySubscript(target, clonedIndices);
return StmtDiff(cloned, result_at_is);
Expand Down Expand Up @@ -1366,8 +1383,10 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
opKind == UnaryOperatorKind::UO_Imag) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_Deref) {
if (Expr* dx = diff.getExpr_dx())
if (Expr* dx = diff.getExpr_dx()) {
EmitCladArrayExtend(diff, getZeroInit(m_Context.IntTy));
return StmtDiff(op, BuildOp(opKind, dx));
}
QualType literalTy =
utils::GetValueType(UnOp->getSubExpr()->getType()->getPointeeType());
return StmtDiff(
Expand Down
46 changes: 43 additions & 3 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -853,11 +853,14 @@ namespace clad {
derivedL = LDiff.getExpr_dx();
derivedR = RDiff.getExpr_dx();
if (utils::isArrayOrPointerType(LDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(RDiff.getExpr()->getType()))
!utils::isArrayOrPointerType(RDiff.getExpr()->getType())) {
derivedR = RDiff.getExpr();
else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType()))
EmitCladArrayExtend(LDiff, derivedR);
} else if (utils::isArrayOrPointerType(RDiff.getExpr()->getType()) &&
!utils::isArrayOrPointerType(LDiff.getExpr()->getType())) {
derivedL = LDiff.getExpr();
EmitCladArrayExtend(RDiff, derivedL);
}
}

Stmt* VisitorBase::GetCladZeroInit(llvm::MutableArrayRef<Expr*> args) {
Expand Down Expand Up @@ -896,4 +899,41 @@ namespace clad {
VisitorBase::GetCladConstructorReverseForwTagOfType(clang::QualType T) {
return InstantiateTemplate(GetCladConstructorReverseForwTag(), {T});
}

void VisitorBase::EmitCladArrayExtend(StmtDiff arr, Expr* idx) {
// FIXME: For now, only forward mode supports not differentiating w.r.t.
// array parameters.
if (m_DiffReq.Mode != DiffMode::forward)
return;
if (isa<DeclRefExpr>(arr.getExpr()->IgnoreImplicit()))
if (auto* MCE =
dyn_cast<CXXMemberCallExpr>(arr.getExpr_dx()->IgnoreImplicit())) {
Expr* cladArr = MCE->getImplicitObjectArgument()->IgnoreImplicit();
if (isCladArrayType(cladArr->getType()) &&
MCE->getDirectCallee()->getNameAsString() == "ptr") {
Expr* size = nullptr;
Expr::EvalResult index;
// If it's possible to determine the index at compile time, generate
// the `extend` argument as a literal. This will help us avoid ugly
// code like
// ```
// _d_arr.extend(0 + 1);
// _d_arr[0] = ...;
// ```
if (idx->EvaluateAsInt(index, m_Context,
Expr::SideEffectsKind::SE_NoSideEffects)) {
size = ConstantFolder::synthesizeLiteral(
m_Context.IntTy, m_Context,
index.Val.getInt().getExtValue() + 1);
} else {
Expr* one = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
m_Context, /*val=*/1);
size = BuildOp(BO_Add, idx, one);
}
llvm::SmallVector<Expr*, 1> param{size};
Expr* extendCall = BuildCallExprToMemFn(cladArr, "extend", param);
addToCurrentBlock(extendCall);
}
}
}
} // end namespace clad
1 change: 1 addition & 0 deletions test/Arrays/ArrayInputsForwardMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ double numMultIndex(double* arr, size_t n, double x) {
}

// CHECK: double numMultIndex_darg2(double *arr, size_t n, double x) {
// CHECK-NEXT: clad::array<double> _d_arr = {};
// CHECK-NEXT: size_t _d_n = 0;
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: bool _d_flag = 0;
Expand Down
18 changes: 15 additions & 3 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,16 @@ float f_literal_args_func(float x, float y, float *z) {
printf("hello world ");
return x * f_literal_helper(0.5, 'a', z, nullptr);
}
// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q);

// CHECK: float f_literal_args_func_darg0(float x, float y, float *z) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: clad::array<float> _d_z = {};
// CHECK-NEXT: printf("hello world ");
// CHECK-NEXT: float _t0 = f_literal_helper(0.5, 'a', z, nullptr);
// CHECK-NEXT: return _d_x * _t0 + x * 0.F;
// CHECK-NEXT: clad::ValueAndPushforward<float, float> _t0 = f_literal_helper_pushforward(0.5, 'a', z, nullptr, 0., 0, _d_z.ptr(), nullptr);
// CHECK-NEXT: float &_t1 = _t0.value;
// CHECK-NEXT: return _d_x * _t1 + x * _t0.pushforward;
// CHECK-NEXT: }

inline unsigned int getBin(double low, double high, double val, unsigned int numBins) {
Expand All @@ -162,8 +165,11 @@ float f_call_inline_fxn(float *params, float const *obs, float const *xlArr) {
// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins);

// CHECK: float f_call_inline_fxn_darg0_0(float *params, const float *obs, const float *xlArr) {
// CHECK-NEXT: clad::array<float> _d_obs = {};
// CHECK-NEXT: clad::array<float> _d_xlArr = {};
// CHECK-NEXT: clad::ValueAndPushforward<unsigned int, unsigned int> _t0 = getBin_pushforward(0., 1., params[0], 1, 0., 0., 1.F, 0);
// CHECK-NEXT: const float _d_t116 = 0.F;
// CHECK-NEXT: _d_xlArr.extend(_t0.value + 1);
// CHECK-NEXT: const float _d_t116 = *(_d_xlArr.ptr() + _t0.value);
// CHECK-NEXT: const float t116 = *(xlArr + _t0.value);
// CHECK-NEXT: return _d_t116 * params[0] + t116 * 1.F;
// CHECK-NEXT: }
Expand Down Expand Up @@ -214,6 +220,12 @@ int main () { // expected-no-diagnostics
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<float, float> f_literal_helper_pushforward(float x, char ch, float *p, float *q, float _d_x, char _d_ch, float *_d_p, float *_d_q) {
// CHECK-NEXT: if (ch == 'a')
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: return {-x * x, -_d_x * x + -x * _d_x};
// CHECK-NEXT: }

// CHECK: inline clad::ValueAndPushforward<unsigned int, unsigned int> getBin_pushforward(double low, double high, double val, unsigned int numBins, double _d_low, double _d_high, double _d_val, unsigned int _d_numBins) {
// CHECK-NEXT: double _t0 = (high - low);
// CHECK-NEXT: double _d_binWidth = ((_d_high - _d_low) * numBins - _t0 * _d_numBins) / (numBins * numBins);
Expand Down
8 changes: 6 additions & 2 deletions test/ForwardMode/Pointer.C
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ double fn9(double* params, const double *constants) {
}

// CHECK: double fn9_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *_d_constants.ptr();
// CHECK-NEXT: double c0 = *constants;
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand All @@ -208,7 +210,9 @@ double fn10(double *params, const double *constants) {
}

// CHECK: double fn10_darg0_0(double *params, const double *constants) {
// CHECK-NEXT: double _d_c0 = 0.;
// CHECK-NEXT: clad::array<double> _d_constants = {};
// CHECK-NEXT: _d_constants.extend(1);
// CHECK-NEXT: double _d_c0 = *(_d_constants.ptr() + 0);
// CHECK-NEXT: double c0 = *(constants + 0);
// CHECK-NEXT: return 1. * c0 + params[0] * _d_c0;
// CHECK-NEXT: }
Expand Down
12 changes: 9 additions & 3 deletions test/ROOT/TFormula.C
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,30 @@ void TFormula_example_grad_1(Double_t* x, Double_t* p, Double_t* _d_p);
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_0(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -1.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (1. + 0. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_1(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 1.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 1. + 0.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

//CHECK: Double_t TFormula_example_darg1_2(Double_t *x, Double_t *p) {
//CHECK-NEXT: clad::array<Double_t> _d_x = {};
//CHECK-NEXT: _d_x.extend(1);
//CHECK-NEXT: {{double|Double_t}} _t0 = (p[0] + p[1] + p[2]);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t1 = clad::custom_derivatives{{(::std)?}}::TMath::Exp_pushforward(-p[0], -0.);
//CHECK-NEXT: clad::ValueAndPushforward<Double_t, Double_t> _t2 = clad::custom_derivatives{{(::std)?}}::TMath::Abs_pushforward(p[1], 0.);
//CHECK-NEXT: return 0. * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: return _d_x.ptr()[0] * _t0 + x[0] * (0. + 0. + 1.) + _t1.pushforward + _t2.pushforward;
//CHECK-NEXT: }

Double_t TFormula_hess1(Double_t *x, Double_t *p) {
Expand Down

0 comments on commit dfde3cf

Please sign in to comment.