
Commit

Added MatPtr/MatPtrT/MatStorageT/MatStorage as a dynamically-sized replacement for CompressedArray.

Definition of array size is moved to the constructor.
Allocation is separate and parallelized.
All users of weights_raw.h are migrated to CompressedWeights, and weights_raw.h is deleted.
Replaced all previous ForEachTensor functions with a single unified function.

PiperOrigin-RevId: 676813839
theraysmith authored and copybara-github committed Oct 10, 2024
1 parent a570e3f commit b188aa7
Showing 35 changed files with 1,573 additions and 1,386 deletions.
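As background for the diffs below, here is a minimal sketch of the MatStorageT idea, inferred only from how this commit uses it (a constructor taking a name plus row/column counts, ZeroInit(), data()); the actual class in compression/compress.h, and the related MatPtr/MatPtrT non-owning views with separate, parallelized allocation, go beyond what this page shows and certainly differ in detail.

// Sketch only; not the real compression/compress.h implementation.
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>

template <typename T>
class MatStorageT {
 public:
  // The shape and a debug name are fixed at construction, replacing the
  // compile-time-sized std::array / CompressedArray members used before.
  MatStorageT(const char* name, size_t rows, size_t cols)
      : name_(name), rows_(rows), cols_(cols),
        data_(std::make_unique<T[]>(rows * cols)) {}

  void ZeroInit() { std::memset(data_.get(), 0, rows_ * cols_ * sizeof(T)); }

  T* data() { return data_.get(); }
  const T* data() const { return data_.get(); }
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }

 private:
  std::string name_;
  size_t rows_;
  size_t cols_;
  std::unique_ptr<T[]> data_;
};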
21 changes: 11 additions & 10 deletions BUILD.bazel
@@ -75,9 +75,7 @@ cc_library(
":allocator",
":threading",
"//compression:compress",
"//compression:sfp",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
@@ -149,7 +147,6 @@ cc_test(
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
@@ -281,11 +278,9 @@ cc_library(
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:bit_set",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)

@@ -481,6 +476,7 @@ cc_library(
":ops",
":prompt",
":weights",
"//compression:compress",
"@hwy//:dot",
"@hwy//:hwy", # base.h
"@hwy//:thread_pool",
@@ -498,9 +494,10 @@ cc_library(
deps = [
":allocator",
":common",
":gemma_lib",
":prompt",
"//compression:weights_raw",
":weights",
"//compression:compress",
"@hwy//:hwy",
],
)

@@ -512,13 +509,15 @@ cc_test(
"backprop/test_util.h",
],
deps = [
":allocator",
":backprop_scalar",
":common",
":gemma_lib",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:weights_raw",
"//compression:compress",
"@hwy//:thread_pool",
],
)

@@ -534,15 +533,17 @@ cc_test(
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":backprop_scalar",
":common",
":gemma_lib",
":ops",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:weights_raw",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 457c891775a7397bdb0376bb1031e6e027af1c48 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)

## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -39,6 +39,7 @@ FetchContent_MakeAvailable(benchmark)
set(SOURCES
compression/blob_store.cc
compression/blob_store.h
compression/compress.cc
compression/compress.h
compression/compress-inl.h
compression/io_win.cc
@@ -48,7 +49,6 @@ set(SOURCES
compression/sfp-inl.h
compression/shared.h
compression/test_util-inl.h
compression/weights_raw.h
backprop/activations.h
backprop/backward.cc
backprop/backward.h
57 changes: 40 additions & 17 deletions backprop/activations.h
@@ -20,49 +20,72 @@

#include <array>

#include "compression/compress.h" // MatStorageT
#include "util/allocator.h" // ByteStorageT

namespace gcpp {

template <typename T, typename TConfig>
struct ForwardLayer {
ForwardLayer()
: input("input", kSeqLen, kModelDim),
pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim),
qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim),
att("att", kSeqLen * kHeads, kSeqLen),
att_out("att_out", kSeqLen * kHeads, kQKVDim),
att_post1("att_post1", kSeqLen, kModelDim),
attention_out("attention_out", kSeqLen, kModelDim),
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim),
ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2),
ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {}

static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
std::array<T, kSeqLen * kModelDim> input;
std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
std::array<T, kSeqLen * kHeads * kSeqLen> att;
std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
std::array<T, kSeqLen * kModelDim> att_post1;
std::array<T, kSeqLen * kModelDim> attention_out;
std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;

MatStorageT<T> input;
MatStorageT<T> pre_att_rms_out;
MatStorageT<T> qkv;
MatStorageT<T> att;
MatStorageT<T> att_out;
MatStorageT<T> att_post1;
MatStorageT<T> attention_out;
MatStorageT<T> bf_pre_ffw_rms_out;
MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated;
};

template <typename T, typename TConfig>
struct ForwardPass {
ForwardPass() {} // prevents placement-new calling memset
ForwardPass()
: final_layer_output("final_layer_output", kSeqLen, kModelDim),
final_norm_output("final_norm_output", kSeqLen, kModelDim),
logits("logits", kSeqLen, kVocabSize),
probs("probs", kSeqLen, kVocabSize) {
} // prevents placement-new calling memset

static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;

std::array<ForwardLayer<T, TConfig>, kLayers> layers;
std::array<T, kSeqLen * kModelDim> final_layer_output;
std::array<T, kSeqLen * kModelDim> final_norm_output;
std::array<T, kSeqLen * kVocabSize> logits;
std::array<T, kSeqLen * kVocabSize> probs;
MatStorageT<T> final_layer_output;
MatStorageT<T> final_norm_output;
MatStorageT<T> logits;
MatStorageT<T> probs;
};

template <typename TConfig>
struct AllocateForwardPass {
ByteStorageT operator()() const {
return AllocateSizeof<ForwardPass<float, TConfig>>();
ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<float, TConfig>>();
auto* c_weights =
reinterpret_cast<ForwardPass<float, TConfig>*>(c_weights_u8.get());
new (c_weights) ForwardPass<float, TConfig>();
return c_weights_u8;
}
};

@@ -74,7 +97,7 @@ class ActivationsWrapper {
public:
ActivationsWrapper()
: data_(AllocateSizeof<WrappedT>()),
activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
activations_(*(new(data_.get()) WrappedT())) {}

const WrappedT& get() const { return activations_; }
WrappedT& get() { return activations_; }
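The AllocateForwardPass and ActivationsWrapper changes above share one pattern: now that the activation structs contain MatStorageT members with real constructors, reinterpret_cast-ing raw ByteStorageT bytes is no longer enough, and the object has to be constructed in place. A minimal sketch of that pattern follows, using simplified stand-ins for ByteStorageT and AllocateSizeof (assumptions; the real aliases live in util/allocator.h and handle alignment).

#include <memory>
#include <new>
#include <utility>

// Stand-in only: the real ByteStorageT is an aligned byte buffer.
using ByteStorageT = std::unique_ptr<unsigned char[]>;

template <typename T>
ByteStorageT AllocateSizeof() {
  // Allocates raw bytes; no T constructor runs here. (Alignment is glossed over.)
  return std::make_unique<unsigned char[]>(sizeof(T));
}

template <typename T, typename... Args>
T& ConstructAt(ByteStorageT& storage, Args&&... args) {
  // Placement-new runs T's constructor inside the pre-allocated bytes, so the
  // MatStorageT members receive their names and dimensions. A plain
  // reinterpret_cast would skip construction entirely. (Destruction is also
  // glossed over in this sketch.)
  return *(new (storage.get()) T(std::forward<Args>(args)...));
}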
17 changes: 7 additions & 10 deletions backprop/backward-inl.h
@@ -168,11 +168,11 @@ static HWY_NOINLINE void InputEmbeddingVJP(
}
}

template <typename TConfig, template <typename> typename LayerT>
void LayerVJP(const LayerT<TConfig>& weights,
template <typename TConfig, typename LayerT>
void LayerVJP(const LayerT& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerT<TConfig>& grad, ForwardLayer<float, TConfig>& backward,
LayerT& grad, ForwardLayer<float, TConfig>& backward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -226,8 +226,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
backward.attention_out.data() + pos * kModelDim, kModelDim);
}

hwy::ZeroBytes(backward.qkv.data(),
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
backward.qkv.ZeroInit();

MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w.data(), forward.att_out.data(),
@@ -343,12 +342,10 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
}
}

template <typename TConfig, template <typename...> typename WeightsT,
template <typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const WeightsT<TConfig>& weights,
template <typename TConfig, typename WeightsT, typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
const ForwardPass<float, TConfig>& forward,
WeightsT<TConfig>& grad,
WeightsT& grad,
ForwardPass<float, TConfig>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
3 changes: 2 additions & 1 deletion backprop/backward.cc
@@ -52,7 +52,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
CompressedLayer<TConfig>>(
prompt, weights, forward, grad, backward, inv_timescale, pool);
}

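The backward-inl.h and backward.cc hunks above replace template-template parameters with ordinary type parameters, so call sites now spell out the concrete weights/layer types. A small, self-contained illustration of that trade-off, with hypothetical stand-in types rather than the repository's, is below.

#include <cstdio>

struct ConfigLike {};  // hypothetical config tag

template <typename TConfig>
struct CompressedWeightsLike {};  // stand-in for CompressedWeights<TConfig>

template <typename TConfig>
struct CompressedLayerLike {};  // stand-in for CompressedLayer<TConfig>

// Before: template <typename TConfig, template <typename...> typename WeightsT, ...>
// After (as in this commit): ordinary type parameters accept any concrete type.
template <typename TConfig, typename WeightsT, typename LayerT>
void BackwardPassLike(const WeightsT& weights, LayerT& grad) {
  (void)weights;
  (void)grad;
  std::puts("instantiated with explicit template arguments");
}

int main() {
  CompressedWeightsLike<ConfigLike> weights;
  CompressedLayerLike<ConfigLike> grad;
  // The concrete types are now named at the call site, as backward.cc does:
  BackwardPassLike<ConfigLike, CompressedWeightsLike<ConfigLike>,
                   CompressedLayerLike<ConfigLike>>(weights, grad);
  return 0;
}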
20 changes: 9 additions & 11 deletions backprop/backward_scalar.h
@@ -25,8 +25,8 @@
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "compression/weights_raw.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"

namespace gcpp {
template<typename T>
@@ -199,13 +199,11 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
}
}

template<typename T, typename TConfig>
void LayerVJP(const Layer<T, TConfig>& weights,
const ForwardLayer<T, TConfig>& forward,
const T* dy,
Layer<T, TConfig>& grad,
ForwardLayer<T, TConfig>& backward,
size_t num_tokens) {
template <typename T, typename TConfig>
void LayerVJP(const CompressedLayer<TConfig>& weights,
const ForwardLayer<T, TConfig>& forward, const T* dy,
CompressedLayer<TConfig>& grad,
ForwardLayer<T, TConfig>& backward, size_t num_tokens) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
@@ -298,11 +296,11 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
}
}

template<typename T, typename TConfig>
template <typename T, typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
const CompressedWeights<TConfig>& weights,
const ForwardPass<T, TConfig>& forward,
Weights<T, TConfig>& grad,
CompressedWeights<TConfig>& grad,
ForwardPass<T, TConfig>& backward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;