Skip to content

Commit

Permalink
Minor cleanup:
Browse files Browse the repository at this point in the history
add app.h comment
separate sum-inl.h and basics.h headers
update include pragmas
update dot_test thresholds

PiperOrigin-RevId: 684398826
  • Loading branch information
jan-wassenberg authored and copybara-github committed Oct 10, 2024
1 parent a570e3f commit d8db853
Show file tree
Hide file tree
Showing 14 changed files with 295 additions and 206 deletions.
11 changes: 11 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ licenses(["notice"])

exports_files(["LICENSE"])

cc_library(
name = "basics",
hdrs = ["util/basics.h"],
deps = [
"@hwy//:hwy",
],
)

cc_library(
name = "allocator",
hdrs = ["util/allocator.h"],
Expand Down Expand Up @@ -66,13 +74,15 @@ cc_library(
],
textual_hdrs = [
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matmul-inl.h",
"ops/matvec-inl.h",
"ops/ops-inl.h",
],
deps = [
":allocator",
":basics",
":threading",
"//compression:compress",
"//compression:sfp",
Expand Down Expand Up @@ -271,6 +281,7 @@ cc_library(
],
deps = [
":allocator",
":basics",
":common",
":ops",
":tokenizer",
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,13 @@ set(SOURCES
ops/matmul-inl.h
ops/matvec-inl.h
ops/ops-inl.h
ops/sum-inl.h
paligemma/image.cc
paligemma/image.h
util/allocator.h
util/app.h
util/args.h
util/basics.h
util/test_util.h
util/threading.h
)
Expand Down
2 changes: 1 addition & 1 deletion compression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ cc_library(
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
"//:allocator",
"//:basics",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
Expand Down
2 changes: 1 addition & 1 deletion compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <cmath> // lroundf, only if COMPRESS_STATS

#include "compression/blob_store.h"
#include "compression/compress.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
Expand Down
2 changes: 1 addition & 1 deletion compression/nuq-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include <stdio.h>

#include "compression/shared.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "hwy/base.h"

#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
Expand Down
3 changes: 2 additions & 1 deletion gemma/gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
Expand Down
2 changes: 1 addition & 1 deletion ops/dot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ class DotStats {

ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
// Extremely high error on aarch64.
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
}

// Backward relative error, lower is better.
Expand Down
5 changes: 3 additions & 2 deletions ops/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
#include <stddef.h>

#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h" // PerClusterPools
#include "util/threading.h"
#include "hwy/aligned_allocator.h" // IWYU pragma: export
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export
#include "hwy/per_target.h"

namespace gcpp {
Expand Down
180 changes: 2 additions & 178 deletions ops/ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <type_traits> // std::enable_if_t

#include "compression/compress.h"
#include "util/allocator.h" // TokenAndProb
#include "util/basics.h" // TokenAndProb
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
Expand All @@ -44,6 +44,7 @@

#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/sum-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/math/math-inl.h"
#include "hwy/profiler.h" // also uses SIMD
Expand Down Expand Up @@ -507,183 +508,6 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}

// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 sums.
struct SumKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
sum0 = hn::Add(sum0, w0);
sum1 = hn::Add(sum1, w1);
sum2 = hn::Add(sum2, w2);
sum3 = hn::Add(sum3, w3);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
} else {
// If not native, the multiplication costs extra, so convert to f32.
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fe1 = hn::PromoteEvenTo(df, w1);
const VF fe2 = hn::PromoteEvenTo(df, w2);
const VF fe3 = hn::PromoteEvenTo(df, w3);
const VF fo0 = hn::PromoteOddTo(df, w0);
const VF fo1 = hn::PromoteOddTo(df, w1);
const VF fo2 = hn::PromoteOddTo(df, w2);
const VF fo3 = hn::PromoteOddTo(df, w3);
const VF fe01 = hn::Add(fe0, fe1);
const VF fe23 = hn::Add(fe2, fe3);
const VF fo01 = hn::Add(fo0, fo1);
const VF fo23 = hn::Add(fo2, fo3);
sum02 = hn::Add(fe01, fe23);
sum13 = hn::Add(fo01, fo23);
}

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);

sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}

// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
VR& comp0) const {
sum0 = hn::Add(sum0, w0);
}

// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
VF f0;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
} else {
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fo0 = hn::PromoteOddTo(df, w0);
f0 = hn::Add(fe0, fo0);
}

const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, f0);
const VS d1 = hn::PromoteUpperTo(ds, f0);

sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}

template <class DState, class VS = hn::Vec<DState>>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
};

// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
// instead of FastTwoSums because the magnitude of the initial sum is not
// always greater than the next input, and this does actually change the e2e
// generation results. Note that Kahan summation differs in that it first adds
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
//
// This about as accurate as SumKernelDouble but slower, hence we only use this
// if f64 is not supported on this target.
struct SumKernelCascaded {
template <typename VT, typename WT>
using Raw = float;
using State = float;

template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
VF serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, sum0, w0, serr0);
sum1 = TwoSums(df, sum1, w1, serr1);
sum2 = TwoSums(df, sum2, w2, serr2);
sum3 = TwoSums(df, sum3, w3, serr3);

comp0 = hn::Add(comp0, serr0);
comp1 = hn::Add(comp1, serr1);
comp2 = hn::Add(comp2, serr2);
comp3 = hn::Add(comp3, serr3);
}

template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF serr0;
sum0 = TwoSums(df, sum0, w0, serr0);

comp0 = hn::Add(comp0, serr0);
}

template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};

using SumKernelDefault =
hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;

template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
const hn::Repartition<Raw, D> d_raw;
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
}

// See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {
Expand Down
Loading

0 comments on commit d8db853

Please sign in to comment.