-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reimplement jacobians using the vectorized forward mode #1121
base: master
Are you sure you want to change the base?
Changes from all commits
3101eef
2300e36
e28b551
ad97b85
35b8a20
a499dde
b14cfeb
edf5cf1
6b97654
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,6 +121,45 @@ class array_expression { | |
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryDiv, RE>( | ||
*this, r); | ||
} | ||
// Operator overload for addition. | ||
template <typename L1, typename BinOp1, typename R1, typename L2, | ||
typename BinOp2, typename R2> | ||
array_expression<const array_expression<L1, BinOp1, R1>&, BinaryAdd, | ||
const array_expression<L2, BinOp2, R2>&> | ||
operator+(const array_expression<L2, BinOp2, R2>& r) const { | ||
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryAdd, | ||
const array_expression<L2, BinOp2, R2>&>(*this, r); | ||
} | ||
|
||
// Operator overload for multiplication. | ||
template <typename L1, typename BinOp1, typename R1, typename L2, | ||
typename BinOp2, typename R2> | ||
array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub, | ||
const array_expression<L2, BinOp2, R2>&> | ||
operator*(const array_expression<L2, BinOp2, R2>& r) const { | ||
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryMul, | ||
const array_expression<L2, BinOp2, R2>&>(*this, r); | ||
} | ||
|
||
// Operator overload for subtraction. | ||
template <typename L1, typename BinOp1, typename R1, typename L2, | ||
typename BinOp2, typename R2> | ||
array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub, | ||
const array_expression<L2, BinOp2, R2>&> | ||
operator-(const array_expression<L2, BinOp2, R2>& r) const { | ||
return array_expression<const array_expression<L1, BinOp1, R1>&, BinarySub, | ||
const array_expression<L2, BinOp2, R2>&>(*this, r); | ||
} | ||
|
||
// Operator overload for division. | ||
template <typename L1, typename BinOp1, typename R1, typename L2, | ||
typename BinOp2, typename R2> | ||
array_expression<const array_expression<L1, BinOp1, R1>&, BinaryDiv, | ||
const array_expression<L2, BinOp2, R2>&> | ||
operator/(const array_expression<L2, BinOp2, R2>& r) const { | ||
return array_expression<const array_expression<L1, BinOp1, R1>&, BinaryDiv, | ||
const array_expression<L2, BinOp2, R2>&>(*this, r); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise. |
||
}; | ||
|
||
// Operator overload for addition, when the right operand is an array_expression | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,14 @@ template <typename T> class array_ref { | |
m_size = a.size(); | ||
return *this; | ||
} | ||
template <typename L, typename BinaryOp, typename R> | ||
CUDA_HOST_DEVICE array_ref<T>& | ||
operator=(const array_expression<L, BinaryOp, R>& arr_exp) { | ||
assert(arr_exp.size() == m_size); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] = arr_exp[i]; | ||
return *this; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise. |
||
/// Returns the size of the underlying array | ||
constexpr CUDA_HOST_DEVICE std::size_t size() const { return m_size; } | ||
constexpr CUDA_HOST_DEVICE PUREFUNC T* ptr() const { return m_arr; } | ||
|
@@ -71,7 +79,7 @@ template <typename T> class array_ref { | |
// Arithmetic overloads | ||
/// Divides the arrays element wise | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator/=(array_ref<U>& Ar) { | ||
CUDA_HOST_DEVICE array_ref<T>& operator/=(const array_ref<U>& Ar) { | ||
assert(m_size == Ar.size() && "Size of both the array_refs must be equal " | ||
"for carrying out addition assignment"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
|
@@ -80,7 +88,7 @@ template <typename T> class array_ref { | |
} | ||
/// Multiplies the arrays element wise | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator*=(array_ref<U>& Ar) { | ||
CUDA_HOST_DEVICE array_ref<T>& operator*=(const array_ref<U>& Ar) { | ||
assert(m_size == Ar.size() && "Size of both the array_refs must be equal " | ||
"for carrying out addition assignment"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
|
@@ -89,7 +97,7 @@ template <typename T> class array_ref { | |
} | ||
/// Adds the arrays element wise | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator+=(array_ref<U>& Ar) { | ||
CUDA_HOST_DEVICE array_ref<T>& operator+=(const array_ref<U>& Ar) { | ||
assert(m_size == Ar.size() && "Size of both the array_refs must be equal " | ||
"for carrying out addition assignment"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
|
@@ -98,36 +106,76 @@ template <typename T> class array_ref { | |
} | ||
/// Subtracts the arrays element wise | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator-=(array_ref<U>& Ar) { | ||
CUDA_HOST_DEVICE array_ref<T>& operator-=(const array_ref<U>& Ar) { | ||
assert(m_size == Ar.size() && "Size of both the array_refs must be equal " | ||
"for carrying out addition assignment"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] -= Ar[i]; | ||
return *this; | ||
} | ||
/// Divides the elements of the array_ref by elements of the array | ||
template <typename U> CUDA_HOST_DEVICE array_ref<T>& operator/=(array<U>& A) { | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator/=(const array<U>& A) { | ||
assert(m_size == A.size() && "Size of arrays must be equal"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] /= A[i]; | ||
return *this; | ||
} | ||
/// Multiplies the elements of the array_ref by elements of the array | ||
template <typename U> CUDA_HOST_DEVICE array_ref<T>& operator*=(array<U>& A) { | ||
template <typename L, typename BinaryOp, typename R> | ||
CUDA_HOST_DEVICE array_ref<T>& | ||
operator*=(const array_expression<L, BinaryOp, R>& arr_exp) { | ||
assert(arr_exp.size() == m_size); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] *= arr_exp[i]; | ||
return *this; | ||
} | ||
/// Adds the elements of the array_ref by elements of the array | ||
template <typename L, typename BinaryOp, typename R> | ||
CUDA_HOST_DEVICE array_ref<T>& | ||
operator+=(const array_expression<L, BinaryOp, R>& arr_exp) { | ||
assert(arr_exp.size() == m_size); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] += arr_exp[i]; | ||
return *this; | ||
} | ||
/// Subtracts the elements of the array_ref by elements of the array | ||
template <typename L, typename BinaryOp, typename R> | ||
CUDA_HOST_DEVICE array_ref<T>& | ||
operator-=(const array_expression<L, BinaryOp, R>& arr_exp) { | ||
assert(arr_exp.size() == m_size); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] -= arr_exp[i]; | ||
return *this; | ||
} | ||
/// Divides the elements of the array_ref by elements of the array | ||
template <typename L, typename BinaryOp, typename R> | ||
CUDA_HOST_DEVICE array_ref<T>& | ||
operator/=(const array_expression<L, BinaryOp, R>& arr_exp) { | ||
assert(arr_exp.size() == m_size); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] /= arr_exp[i]; | ||
return *this; | ||
} | ||
/// Multiplies the elements of the array_ref by elements of the array | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator*=(const array<U>& A) { | ||
assert(m_size == A.size() && "Size of arrays must be equal"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] *= A[i]; | ||
return *this; | ||
} | ||
/// Adds the elements of the array_ref by elements of the array | ||
template <typename U> CUDA_HOST_DEVICE array_ref<T>& operator+=(array<U>& A) { | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator+=(const array<U>& A) { | ||
assert(m_size == A.size() && "Size of arrays must be equal"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] += A[i]; | ||
return *this; | ||
} | ||
/// Subtracts the elements of the array_ref by elements of the array | ||
template <typename U> CUDA_HOST_DEVICE array_ref<T>& operator-=(array<U>& A) { | ||
template <typename U> | ||
CUDA_HOST_DEVICE array_ref<T>& operator-=(const array<U>& A) { | ||
assert(m_size == A.size() && "Size of arrays must be equal"); | ||
for (std::size_t i = 0; i < m_size; i++) | ||
m_arr[i] -= A[i]; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -199,31 +199,31 @@ CUDA_HOST_DEVICE inline void __builtin_powf_pullback(float x, float exponent, | |
// FIXME: Add the rest of the __builtin_ routines for log, sqrt, abs, etc. | ||
|
||
namespace std { | ||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> abs_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> abs_pushforward(T x, dT d_x) { | ||
if (x >= 0) | ||
return {x, d_x}; | ||
else | ||
return {-x, -d_x}; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> exp_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> exp_pushforward(T x, dT d_x) { | ||
return {::std::exp(x), ::std::exp(x) * d_x}; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> sin_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> sin_pushforward(T x, dT d_x) { | ||
return {::std::sin(x), ::std::cos(x) * d_x}; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> cos_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> cos_pushforward(T x, dT d_x) { | ||
return {::std::cos(x), (-1) * ::std::sin(x) * d_x}; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> sqrt_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> sqrt_pushforward(T x, dT d_x) { | ||
return {::std::sqrt(x), (((T)1) / (((T)2) * ::std::sqrt(x))) * d_x}; | ||
} | ||
|
||
|
@@ -232,9 +232,9 @@ CUDA_HOST_DEVICE ValueAndPushforward<T, T> floor_pushforward(T x, T /*d_x*/) { | |
return {::std::floor(x), (T)0}; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> atan2_pushforward(T y, T x, T d_y, | ||
T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> atan2_pushforward(T y, T x, dT d_y, | ||
dT d_x) { | ||
return {::std::atan2(y, x), | ||
-(y / ((x * x) + (y * y))) * d_x + x / ((x * x) + (y * y)) * d_y}; | ||
} | ||
|
@@ -246,8 +246,8 @@ CUDA_HOST_DEVICE void atan2_pullback(T y, T x, U d_z, T* d_y, T* d_x) { | |
*d_x += -(y / ((x * x) + (y * y))) * d_z; | ||
} | ||
|
||
template <typename T> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, T> acos_pushforward(T x, T d_x) { | ||
template <typename T, typename dT> | ||
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> acos_pushforward(T x, dT d_x) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are these changes not good for a separate PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They can exist on their own but they will have no use. The types don't match in the vectorized fwd mode because those will be |
||
return {::std::acos(x), ((-1) / (::std::sqrt(1 - x * x))) * d_x}; | ||
} | ||
|
||
|
@@ -270,10 +270,31 @@ pow_pushforward(T1 x, T2 exponent, T1 d_x, T2 d_exponent) { | |
auto val = ::std::pow(x, exponent); | ||
auto derivative = (exponent * ::std::pow(x, exponent - 1)) * d_x; | ||
// Only add directional derivative of base^exp w.r.t exp if the directional | ||
// seed d_exponent is non-zero. This is required because if base is less than or | ||
// equal to 0, then log(base) is undefined, and therefore if user only requested | ||
// directional derivative of base^exp w.r.t base -- which is valid --, the result would | ||
// be undefined because as per C++ valid number + NaN * 0 = NaN. | ||
// seed d_exponent is non-zero. This is required because if base is less than | ||
// or equal to 0, then log(base) is undefined, and therefore if user only | ||
// requested directional derivative of base^exp w.r.t base -- which is valid | ||
// --, the result would be undefined because as per C++ valid number + NaN * 0 | ||
// = NaN. | ||
if (d_exponent) | ||
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent; | ||
return {val, derivative}; | ||
} | ||
|
||
template <typename T1, typename T2> | ||
CUDA_HOST_DEVICE | ||
ValueAndPushforward<decltype(::std::pow(T1(), T2())), | ||
clad::array<decltype(::std::pow(T1(), T2()))>> | ||
pow_pushforward(T1 x, T2 exponent, clad::array<T1> d_x, | ||
clad::array<T2> d_exponent) { | ||
decltype(::std::pow(T1(), T2())) val = ::std::pow(x, exponent); | ||
clad::array<decltype(::std::pow(T1(), T2()))> derivative = | ||
(exponent * ::std::pow(x, exponent - 1)) * d_x; | ||
// Only add directional derivative of base^exp w.r.t exp if the directional | ||
// seed d_exponent is non-zero. This is required because if base is less than | ||
// or equal to 0, then log(base) is undefined, and therefore if user only | ||
// requested directional derivative of base^exp w.r.t base -- which is valid | ||
// --, the result would be undefined because as per C++ valid number + NaN * 0 | ||
// = NaN. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this overload? |
||
if (d_exponent) | ||
derivative += (::std::pow(x, exponent) * ::std::log(x)) * d_exponent; | ||
return {val, derivative}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we write tests for these? I think we generally test in
Misc
.