diff --git a/include/Math/Dual.hpp b/include/Math/Dual.hpp index c0692bd..deb594a 100644 --- a/include/Math/Dual.hpp +++ b/include/Math/Dual.hpp @@ -809,12 +809,13 @@ constexpr auto softplus(const Dual &x) -> Dual { } template constexpr auto log(const Dual &x) -> Dual { - return {log2(x.value()), x.gradient() / x.value()}; + constexpr double logof2 = 0.6931471805599453; // log(2); + return {log2(x.value()) * logof2, x.gradient() / x.value()}; } template constexpr auto log2(const Dual &x) -> Dual { - constexpr double log2 = 0.6931471805599453; // log(2); - return {log2(x.value()), x.gradient() / (log2 * x.value())}; + constexpr double logof2 = 0.6931471805599453; // log(2); + return {log2(x.value()), x.gradient() / (logof2 * x.value())}; } constexpr auto dval(double &x) -> double & { return x; } diff --git a/include/Math/Exp.hpp b/include/Math/Exp.hpp index 48afc07..ebc64fd 100644 --- a/include/Math/Exp.hpp +++ b/include/Math/Exp.hpp @@ -430,7 +430,8 @@ template constexpr auto smax(auto x, auto y) { template constexpr auto smax(auto x, auto y, auto z) { double m = std::max(std::max(value(x), value(y)), value(z)); - return m + log(exp(l * (x - m)) + exp(l * (y - m)) + exp(l * (z - m))) / l; + constexpr double f = l, i = 1 / f; + return m + log(exp(f * (x - m)) + exp(f * (y - m)) + exp(f * (z - m))) * i; } template constexpr auto smin(auto x, auto y) { return smax<-l>(x, y);