Added MatPtr/MatPtrT/MatStorageT/MatStorage as a dynamically-sized replacement for CompressedArray. #417

Merged (1 commit) · Oct 10, 2024
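Before the per-file diffs, a quick orientation: the new MatStorageT members visible in backprop/activations.h below are constructed from a (name, rows, cols) triple and expose data() and ZeroInit(). The following minimal sketch shows the shape of such a dynamically sized buffer; only those three members are taken from this diff, and everything else (the sketch's class name, the unique_ptr storage, NumElements) is an assumption, not the actual compression/compress.h implementation.

// Minimal sketch of a dynamically sized matrix buffer in the spirit of
// MatStorageT. Only the (name, rows, cols) constructor, data(), and
// ZeroInit() appear in this PR's diff; everything else is an assumption.
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>

template <typename T>
class MatStorageSketch {
 public:
  MatStorageSketch(const char* name, std::size_t rows, std::size_t cols)
      : name_(name),
        rows_(rows),
        cols_(cols),
        data_(std::make_unique<T[]>(rows * cols)) {}

  T* data() { return data_.get(); }
  const T* data() const { return data_.get(); }
  std::size_t NumElements() const { return rows_ * cols_; }

  // Used in backward-inl.h below to replace a manual hwy::ZeroBytes call.
  void ZeroInit() { std::memset(data_.get(), 0, NumElements() * sizeof(T)); }

 private:
  std::string name_;  // label, e.g. for debugging or serialization
  std::size_t rows_;
  std::size_t cols_;
  std::unique_ptr<T[]> data_;
};

With this shape, ForwardLayer's members shrink from compile-time std::array extents to run-time sizes chosen in the constructor's initializer list, as the activations.h diff shows.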
21 changes: 11 additions & 10 deletions BUILD.bazel
@@ -75,9 +75,7 @@ cc_library(
":allocator",
":threading",
"//compression:compress",
"//compression:sfp",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
@@ -149,7 +147,6 @@ cc_test(
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
@@ -281,11 +278,9 @@ cc_library(
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:bit_set",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)

@@ -481,6 +476,7 @@ cc_library(
":ops",
":prompt",
":weights",
"//compression:compress",
"@hwy//:dot",
"@hwy//:hwy", # base.h
"@hwy//:thread_pool",
@@ -498,9 +494,10 @@ cc_library(
deps = [
":allocator",
":common",
":gemma_lib",
":prompt",
"//compression:weights_raw",
":weights",
"//compression:compress",
"@hwy//:hwy",
],
)

@@ -512,13 +509,15 @@ cc_test(
"backprop/test_util.h",
],
deps = [
":allocator",
":backprop_scalar",
":common",
":gemma_lib",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:weights_raw",
"//compression:compress",
"@hwy//:thread_pool",
],
)

@@ -534,15 +533,17 @@ cc_test(
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":backprop_scalar",
":common",
":gemma_lib",
":ops",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:weights_raw",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 457c891775a7397bdb0376bb1031e6e027af1c48 EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)

## Note: absl needs to be installed by sentencepiece. This will only happen if
@@ -39,6 +39,7 @@ FetchContent_MakeAvailable(benchmark)
set(SOURCES
compression/blob_store.cc
compression/blob_store.h
compression/compress.cc
compression/compress.h
compression/compress-inl.h
compression/io_win.cc
@@ -48,7 +49,6 @@ set(SOURCES
compression/sfp-inl.h
compression/shared.h
compression/test_util-inl.h
compression/weights_raw.h
backprop/activations.h
backprop/backward.cc
backprop/backward.h
57 changes: 40 additions & 17 deletions backprop/activations.h
@@ -20,49 +20,72 @@

#include <array>

#include "compression/compress.h" // MatStorageT
#include "util/allocator.h" // ByteStorageT

namespace gcpp {

template <typename T, typename TConfig>
struct ForwardLayer {
ForwardLayer()
: input("input", kSeqLen, kModelDim),
pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim),
qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim),
att("att", kSeqLen * kHeads, kSeqLen),
att_out("att_out", kSeqLen * kHeads, kQKVDim),
att_post1("att_post1", kSeqLen, kModelDim),
attention_out("attention_out", kSeqLen, kModelDim),
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim),
ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2),
ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {}

static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
std::array<T, kSeqLen * kModelDim> input;
std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
std::array<T, kSeqLen * kHeads * kSeqLen> att;
std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
std::array<T, kSeqLen * kModelDim> att_post1;
std::array<T, kSeqLen * kModelDim> attention_out;
std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;

MatStorageT<T> input;
MatStorageT<T> pre_att_rms_out;
MatStorageT<T> qkv;
MatStorageT<T> att;
MatStorageT<T> att_out;
MatStorageT<T> att_post1;
MatStorageT<T> attention_out;
MatStorageT<T> bf_pre_ffw_rms_out;
MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated;
};

template <typename T, typename TConfig>
struct ForwardPass {
ForwardPass() {} // prevents placement-new calling memset
ForwardPass()
: final_layer_output("final_layer_output", kSeqLen, kModelDim),
final_norm_output("final_norm_output", kSeqLen, kModelDim),
logits("logits", kSeqLen, kVocabSize),
probs("probs", kSeqLen, kVocabSize) {
} // prevents placement-new calling memset

static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;

std::array<ForwardLayer<T, TConfig>, kLayers> layers;
std::array<T, kSeqLen * kModelDim> final_layer_output;
std::array<T, kSeqLen * kModelDim> final_norm_output;
std::array<T, kSeqLen * kVocabSize> logits;
std::array<T, kSeqLen * kVocabSize> probs;
MatStorageT<T> final_layer_output;
MatStorageT<T> final_norm_output;
MatStorageT<T> logits;
MatStorageT<T> probs;
};

template <typename TConfig>
struct AllocateForwardPass {
ByteStorageT operator()() const {
return AllocateSizeof<ForwardPass<float, TConfig>>();
ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<float, TConfig>>();
auto* c_weights =
reinterpret_cast<ForwardPass<float, TConfig>*>(c_weights_u8.get());
new (c_weights) ForwardPass<float, TConfig>();
return c_weights_u8;
}
};

@@ -74,7 +97,7 @@ class ActivationsWrapper {
public:
ActivationsWrapper()
: data_(AllocateSizeof<WrappedT>()),
activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
activations_(*(new(data_.get()) WrappedT())) {}

const WrappedT& get() const { return activations_; }
WrappedT& get() { return activations_; }
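Both changes in activations.h above follow one pattern: now that these structs own MatStorageT members with non-trivial constructors, a raw byte buffer must be initialized via placement new rather than a bare reinterpret_cast. A self-contained sketch of that pattern follows; ByteStorageT, AllocateSizeof, and the example struct are stand-ins, not the real definitions from util/allocator.h.

// Sketch of the allocate-then-placement-new pattern used by
// AllocateForwardPass and ActivationsWrapper above. These helpers are
// assumptions standing in for the real ones in util/allocator.h.
#include <cstdint>
#include <memory>
#include <new>

using ByteStorageT = std::unique_ptr<std::uint8_t[]>;

template <typename T>
ByteStorageT AllocateSizeof() {
  // A real allocator would also guarantee alignof(T); omitted in this sketch.
  return std::make_unique<std::uint8_t[]>(sizeof(T));
}

struct HasNontrivialMembers {  // stands in for ForwardPass<float, TConfig>
  HasNontrivialMembers() : buffer(new float[16]()) {}
  std::unique_ptr<float[]> buffer;  // like a MatStorageT member
};

ByteStorageT Allocate() {
  ByteStorageT u8 = AllocateSizeof<HasNontrivialMembers>();
  // reinterpret_cast alone would leave `buffer` uninitialized; placement
  // new runs the constructor inside the pre-allocated bytes. Note that
  // nothing in this sketch ever invokes the destructor.
  new (u8.get()) HasNontrivialMembers();
  return u8;
}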
17 changes: 7 additions & 10 deletions backprop/backward-inl.h
@@ -168,11 +168,11 @@ static HWY_NOINLINE void InputEmbeddingVJP(
}
}

template <typename TConfig, template <typename> typename LayerT>
void LayerVJP(const LayerT<TConfig>& weights,
template <typename TConfig, typename LayerT>
void LayerVJP(const LayerT& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerT<TConfig>& grad, ForwardLayer<float, TConfig>& backward,
LayerT& grad, ForwardLayer<float, TConfig>& backward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -226,8 +226,7 @@ void LayerVJP(const LayerT<TConfig>& weights,
backward.attention_out.data() + pos * kModelDim, kModelDim);
}

hwy::ZeroBytes(backward.qkv.data(),
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
backward.qkv.ZeroInit();

MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w.data(), forward.att_out.data(),
@@ -343,12 +342,10 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
}
}

template <typename TConfig, template <typename...> typename WeightsT,
template <typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const WeightsT<TConfig>& weights,
template <typename TConfig, typename WeightsT, typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
const ForwardPass<float, TConfig>& forward,
WeightsT<TConfig>& grad,
WeightsT& grad,
ForwardPass<float, TConfig>& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
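The signature change above replaces a template-template parameter with a plain typename, so callers now pass the fully instantiated layer/weights type, as the backward.cc hunk below shows. A standalone illustration of the two styles, using toy types rather than the real gemma.cpp ones:

// Standalone illustration of the template change in LayerVJP and
// CrossEntropyLossBackwardPass; ConfigA and CompressedLayer here are
// toy stand-ins, not the real gemma.cpp types.
struct ConfigA {};

template <typename TConfig>
struct CompressedLayer {};

// Old style: the callee instantiates LayerT<TConfig> itself, so LayerT
// must be a template over exactly one type parameter.
template <typename TConfig, template <typename> typename LayerT>
void OldStyle(const LayerT<TConfig>& weights) {}

// New style (as in this PR): the caller passes the instantiated type,
// which also admits layer types not parameterized by TConfig.
template <typename TConfig, typename LayerT>
void NewStyle(const LayerT& weights) {}

int main() {
  CompressedLayer<ConfigA> layer;
  OldStyle<ConfigA, CompressedLayer>(layer);
  NewStyle<ConfigA>(layer);  // LayerT deduced as CompressedLayer<ConfigA>
}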
3 changes: 2 additions & 1 deletion backprop/backward.cc
@@ -52,7 +52,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
CompressedLayer<TConfig>>(
prompt, weights, forward, grad, backward, inv_timescale, pool);
}

20 changes: 9 additions & 11 deletions backprop/backward_scalar.h
@@ -25,8 +25,8 @@
#include "backprop/activations.h"
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "compression/weights_raw.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"

namespace gcpp {
template<typename T>
@@ -199,13 +199,11 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
}
}

template<typename T, typename TConfig>
void LayerVJP(const Layer<T, TConfig>& weights,
const ForwardLayer<T, TConfig>& forward,
const T* dy,
Layer<T, TConfig>& grad,
ForwardLayer<T, TConfig>& backward,
size_t num_tokens) {
template <typename T, typename TConfig>
void LayerVJP(const CompressedLayer<TConfig>& weights,
const ForwardLayer<T, TConfig>& forward, const T* dy,
CompressedLayer<TConfig>& grad,
ForwardLayer<T, TConfig>& backward, size_t num_tokens) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
@@ -298,11 +296,11 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
}
}

template<typename T, typename TConfig>
template <typename T, typename TConfig>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<T, TConfig>& weights,
const CompressedWeights<TConfig>& weights,
const ForwardPass<T, TConfig>& forward,
Weights<T, TConfig>& grad,
CompressedWeights<TConfig>& grad,
ForwardPass<T, TConfig>& backward) {
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;