Skip to content

Commit

Permalink
Add support for array arguments in vector mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Aug 5, 2023
1 parent 1d8f664 commit e685e2b
Show file tree
Hide file tree
Showing 12 changed files with 550 additions and 141 deletions.
12 changes: 6 additions & 6 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,44 +154,44 @@ template <typename T> class array {
return *this;
}
/// Initializes the clad::array to the given clad::array_ref
CUDA_HOST_DEVICE array<T>& operator=(array_ref<T>& arr) {
CUDA_HOST_DEVICE array<T>& operator=(const array_ref<T>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] = arr[i];
return *this;
}

template <typename U>
CUDA_HOST_DEVICE array<T>& operator=(array_ref<U>& arr) {
CUDA_HOST_DEVICE array<T>& operator=(const array_ref<U>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] = arr[i];
return *this;
}

/// Performs element wise addition
CUDA_HOST_DEVICE array<T>& operator+=(array_ref<T>& arr) {
CUDA_HOST_DEVICE array<T>& operator+=(const array_ref<T>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] += arr[i];
return *this;
}
/// Performs element wise subtraction
CUDA_HOST_DEVICE array<T>& operator-=(array_ref<T>& arr) {
CUDA_HOST_DEVICE array<T>& operator-=(const array_ref<T>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] -= arr[i];
return *this;
}
/// Performs element wise multiplication
CUDA_HOST_DEVICE array<T>& operator*=(array_ref<T>& arr) {
CUDA_HOST_DEVICE array<T>& operator*=(const array_ref<T>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] *= arr[i];
return *this;
}
/// Performs element wise division
CUDA_HOST_DEVICE array<T>& operator/=(array_ref<T>& arr) {
CUDA_HOST_DEVICE array<T>& operator/=(const array_ref<T>& arr) {
assert(arr.size() == m_size);
for (std::size_t i = 0; i < m_size; i++)
m_arr[i] /= arr[i];
Expand Down
114 changes: 113 additions & 1 deletion include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ template <typename T> class array_ref {
/// Constructor for clad::array types
CUDA_HOST_DEVICE array_ref(array<T>& a) : m_arr(a.ptr()), m_size(a.size()) {}

template <typename U> CUDA_HOST_DEVICE array_ref<T>& operator=(array<U>& a) {
template <typename U>
CUDA_HOST_DEVICE array_ref<T>& operator=(const array<U>& a) {
assert(m_size == a.size());
for (std::size_t i = 0; i < m_size; ++i)
m_arr[i] = a[i];
Expand All @@ -53,6 +54,7 @@ template <typename T> class array_ref {
/// Returns the reference to the location at the index of the underlying
/// array
CUDA_HOST_DEVICE T& operator[](std::size_t i) { return m_arr[i]; }
CUDA_HOST_DEVICE const T& operator[](std::size_t i) const { return m_arr[i]; }
/// Returns the reference to the underlying array
CUDA_HOST_DEVICE T& operator*() { return *m_arr; }

Expand Down Expand Up @@ -156,6 +158,116 @@ template <typename T> class array_ref {
}
};

/// Overloaded operators for clad::array_ref which returns a new clad::array
/// object.

/// Multiplies the arrays element wise
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator*(const array_ref<T>& Ar,
const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out addition "
"assignment");
array<T> C(Ar);
C *= Br;
return C;
}

/// Adds the arrays element wise
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator+(const array_ref<T>& Ar,
const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out addition "
"assignment");
array<T> C(Ar);
C += Br;
return C;
}

/// Subtracts the arrays element wise
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator-(const array_ref<T>& Ar,
const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out addition "
"assignment");
array<T> C(Ar);
C -= Br;
return C;
}

/// Divides the arrays element wise
template <typename T, typename U>
CUDA_HOST_DEVICE array<T> operator/(const array_ref<T>& Ar,
const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out addition "
"assignment");
array<T> C(Ar);
C /= Br;
return C;
}

/// Multiplies array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(const array_ref<T>& Ar, U a) {
array<T> C(Ar);
C *= a;
return C;
}

/// Multiplies array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator*(U a, const array_ref<T>& Ar) {
return Ar * a;
}

/// Divides array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator/(const array_ref<T>& Ar, U a) {
array<T> C(Ar);
C /= a;
return C;
}

/// Adds array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(const array_ref<T>& Ar, U a) {
array<T> C(Ar);
C += a;
return C;
}

/// Adds array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator+(U a, const array_ref<T>& Ar) {
return Ar + a;
}

/// Subtracts array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator-(const array_ref<T>& Ar, U a) {
array<T> C(Ar);
C -= a;
return C;
}

/// Subtracts array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array<T> operator-(U a, const array_ref<T>& Ar) {
array<T> C(Ar.size(), a);
C -= Ar;
return C;
}

/// `array_ref<void>` specialisation is created to be used as a placeholder
/// type in the overloaded derived function. All `array_ref<T>` types are
/// implicitly convertible to `array_ref<void>` type.
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class BaseForwardModeVisitor

static bool IsDifferentiableType(clang::QualType T);

StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
virtual StmtDiff
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
Expand Down
1 change: 0 additions & 1 deletion include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ namespace clad {
else
return "<invalid>";
}

void dump() const {
printf("The code is: \n%s\n", getCode());
}
Expand Down
4 changes: 3 additions & 1 deletion include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,10 @@ namespace clad {
using ExtractDerivedFnTraitsForwMode_t =
typename ExtractDerivedFnTraitsForwMode<F>::type;

// OutputVecParamType is used to deduce the type of derivative arguments
// for vector forward mode.
template <class T, class R> struct OutputVecParamType {
using type = typename std::add_pointer<R>::type;
using type = array_ref<typename std::remove_pointer<R>::type>;
};

template <class T, class R>
Expand Down
7 changes: 7 additions & 0 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
/// m_Variables map because all other intermediate variables will have
/// derivatives as vectors.
std::unordered_map<const clang::ValueDecl*, clang::Expr*> m_ParamVariables;
/// Expression for total number of independent variables. This also includes
/// the size of array independent variables which will be inferred from the
/// size of the corresponding clad array they provide at runtime for storing
/// the derivatives.
clang::Expr* m_IndVarCountExpr;

public:
VectorForwardModeVisitor(DerivativeBuilder& builder);
Expand Down Expand Up @@ -68,6 +73,8 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
/// For example: for size = 4, the returned expression is: {0, 0, 0, 0}
clang::Expr* getZeroInitListExpr(size_t size, clang::QualType type);

StmtDiff
VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD) override;
Expand Down
10 changes: 10 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ namespace clad {
clang::TemplateDecl* GetCladArrayDecl();
/// Create clad::array<T> type.
clang::QualType GetCladArrayOfType(clang::QualType T);
/// Find declaration of clad::matrix templated type.
clang::TemplateDecl* GetCladMatrixDecl();
/// Create clad::matrix<T> type.
clang::QualType GetCladMatrixOfType(clang::QualType T);
/// Creates the expression clad::matrix<T>::identity(Args) for the given
/// type and args.
clang::Expr*
BuildIdentityMatrixExpr(clang::QualType T,
llvm::MutableArrayRef<clang::Expr*> Args,
clang::SourceLocation Loc);
/// Creates the expression Base.size() for the given Base expr. The Base
/// expr must be of clad::array_ref<T> type
clang::Expr* BuildArrayRefSizeExpr(clang::Expr* Base);
Expand Down
Loading

0 comments on commit e685e2b

Please sign in to comment.