Skip to content

Commit

Permalink
log fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Feb 16, 2024
1 parent 5738844 commit 5f94615
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
7 changes: 4 additions & 3 deletions include/Math/Dual.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,12 +809,13 @@ constexpr auto softplus(const Dual<T, N> &x) -> Dual<T, N> {
}
template <class T, ptrdiff_t N>
constexpr auto log(const Dual<T, N> &x) -> Dual<T, N> {
return {log2(x.value()), x.gradient() / x.value()};
constexpr double logof2 = 0.6931471805599453; // log(2);
return {log2(x.value()) * logof2, x.gradient() / x.value()};
}
template <class T, ptrdiff_t N>
constexpr auto log2(const Dual<T, N> &x) -> Dual<T, N> {
constexpr double log2 = 0.6931471805599453; // log(2);
return {log2(x.value()), x.gradient() / (log2 * x.value())};
constexpr double logof2 = 0.6931471805599453; // log(2);
return {log2(x.value()), x.gradient() / (logof2 * x.value())};
}

constexpr auto dval(double &x) -> double & { return x; }
Expand Down
3 changes: 2 additions & 1 deletion include/Math/Exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ template <int l = 8> constexpr auto smax(auto x, auto y) {

template <int l = 8> constexpr auto smax(auto x, auto y, auto z) {
double m = std::max(std::max(value(x), value(y)), value(z));
return m + log(exp(l * (x - m)) + exp(l * (y - m)) + exp(l * (z - m))) / l;
constexpr double f = l, i = 1 / f;
return m + log(exp(f * (x - m)) + exp(f * (y - m)) + exp(f * (z - m))) * i;
}
template <int l = 8> constexpr auto smin(auto x, auto y) {
return smax<-l>(x, y);
Expand Down

0 comments on commit 5f94615

Please sign in to comment.