From a570e3f6624e8958327f824ccc8073c105abb2d5 Mon Sep 17 00:00:00 2001
From: Daniel Keysers
Date: Wed, 9 Oct 2024 07:50:02 -0700
Subject: [PATCH] Reduce number of operations in Gelu() by one Mul.

About 5% faster Gen.Activation.

PiperOrigin-RevId: 684035719
---
 ops/ops-inl.h | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index adb06fe..f03159a 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -65,16 +65,19 @@ StaticCast(From from) noexcept {
   }
 }
 
+// We use the tanh approximation for gelu (also used in training).
+// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+//         = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2)))
+//         = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2)))
+//         = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2))))
 template <class D>
 HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
-  const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
+  const hn::Vec<D> kMul = hn::Set(d, 0.03567740813636141f);
   const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
   const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
 
-  // tanh approximation matches training.
-  const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
-  const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
-  // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
+  const hn::Vec<D> v2 = hn::Mul(v, v);
+  const hn::Vec<D> arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi));
   const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
   return hn::Mul(v, cdf);
 }