feat: new device support phi
rtp-llm committed Jun 21, 2024
1 parent c6a9861 commit afeb2b3
Showing 3 changed files with 40 additions and 6 deletions.
37 changes: 35 additions & 2 deletions maga_transformer/cpp/models/GptModel.cc
@@ -111,6 +111,37 @@ void GptModel::prepareAttentionInputs(
attention_inputs.attention_mask = inputs.attention_mask;
}


/*
* ┌───────────┐
* │ hidden │
* └─────┬─────┘
* │
* │
* ┌───────▼───────┐
* │ pre_layernorm?├─────────┐
* └───────┬───────┘ │
* │ │
* ┌─────▼─────┐ │
* │ attention │ │
* └─────┬─────┘ │
* │ │
* ┌───────▼───────┐ │
* ┌──────┤post_attn_norm?◄─────────┘
* │ └───────┬───────┘
* │ │
* │ ┌────▼────┐
* │ │ mlp │
* │ └────┬────┘
* │ │
* │ ┌────▼────┐
* └─────────► add │
* └────┬────┘
* │
* ┌─────▼─────┐
* │ layernorm │
* └───────────┘
*/
GptModelOutputs GptModel::forward(const GptModelInputs& inputs) {
const auto norm_type = description_.norm_type;
const auto norm_eps = description_.layernorm_eps;
@@ -163,6 +194,7 @@ GptModelOutputs GptModel::forward(const GptModelInputs& inputs) {

auto attn_out_buf = device_->allocateBuffer({hidden->type(), hidden->shape()}, {"attn_out_buf"});
auto residual = hidden;
BufferPtr residual2 = nullptr;
if (layer.pre_layernorm) {
residual = device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
device_->layernorm(LayernormParams(
@@ -205,7 +237,7 @@ GptModelOutputs GptModel::forward(const GptModelInputs& inputs) {
residual = attn_hidden;
}
} else {
-hidden = move(attn_hidden);
+residual2 = attn_hidden;
}

printBufferData(*hidden, "layer_" + to_string(i) + "_ffn_input");
@@ -226,7 +258,8 @@ GptModelOutputs GptModel::forward(const GptModelInputs& inputs) {
*hidden, *hidden, nullopt,
norm_type, ft::mayGetRef(layer.post_ffn_layernorm), norm_eps,
device_props_.ffn_fuse_add_residual ? nullopt : (OptionalConstBufferRef)*residual,
-nullopt, ft::mayGetRef(layer.ffn_weights.down_weight->bias)));
+(residual2 == nullptr) ? nullopt : (OptionalConstBufferRef)*residual2,
+ft::mayGetRef(layer.ffn_weights.down_weight->bias)));

printBufferData(*hidden, "layer_" + to_string(i) + "_final_hidden");
}
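The key change above is in the branch without a post-attention layernorm: instead of overwriting hidden with the attention output (the removed hidden = move(attn_hidden)), the commit keeps the attention output aside as residual2 and folds it into the fused add of the post-FFN layernorm call. The sketch below illustrates the resulting parallel ("phi-style") block structure; it is a simplified reading of the diff, and all types and helpers are stand-ins rather than the real rtp-llm / FasterTransformer API.

// Minimal sketch of the parallel ("phi-style") decoder block enabled by the
// residual2 path above. Simplified stand-ins only; the real code dispatches
// to device kernels and fuses the residual adds into the layernorm call.
#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;

// Placeholder ops so the sketch is self-contained.
static Tensor layernorm(const Tensor& x) { return x; }
static Tensor attention(const Tensor& x) { return x; }
static Tensor mlp(const Tensor& x)       { return x; }

// Elementwise sum of two or, optionally, three tensors of equal size.
static Tensor add(const Tensor& a, const Tensor& b, const Tensor* c = nullptr) {
    Tensor out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i)
        out[i] = a[i] + b[i] + (c ? (*c)[i] : 0.0f);
    return out;
}

// Without a post-attention layernorm, attention and MLP both consume the same
// pre-normalized hidden; the attention output travels as residual2 and is only
// merged in the final add (the "add" box in the diagram above).
Tensor parallelBlock(const Tensor& x) {
    Tensor residual  = x;                  // first residual: the block input
    Tensor hidden    = layernorm(x);       // pre_layernorm
    Tensor residual2 = attention(hidden);  // kept aside instead of overwriting hidden
    Tensor ffn_out   = mlp(hidden);        // MLP reads hidden, not the attention output
    // Final add; the trailing norm follows the diagram and may be fused or
    // optional (post_ffn_layernorm) in the real code.
    return layernorm(add(ffn_out, residual, &residual2));
}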
7 changes: 4 additions & 3 deletions src/fastertransformer/devices/base_impl/FfnLayer.cc
@@ -1,6 +1,6 @@
#include "src/fastertransformer/devices/DeviceBase.h"
#include "src/fastertransformer/devices/OpData.h"

#include "src/fastertransformer/devices/utils/DebugUtils.h"
using namespace std;

namespace fastertransformer {
@@ -46,7 +46,7 @@ FfnLayerOutput DeviceBase::ffnLayer(const FfnLayerParams& params) {
std::nullopt,
*(params.weights.up_weight),
std::nullopt});

printBufferData(*up_output.output, "ffn_up");
if (FFNDispatch::dispatch(params) == FFNDispatch::FFNType::Gate) {
{
auto gate_output = loraLinear({params.input,
@@ -73,11 +73,12 @@ FfnLayerOutput DeviceBase::ffnLayer(const FfnLayerParams& params) {
mayGetRef(params.weights.up_weight->bias),
std::nullopt,
std::nullopt});

printBufferData(*up_output.output, "ffn_act");
auto output = loraLinear({*(up_output.output),
std::nullopt,
*(params.weights.down_weight),
std::nullopt});
printBufferData(*output.output, "ffn_out");
return FfnLayerOutput({move(output.output)});
} else {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
2 changes: 1 addition & 1 deletion src/fastertransformer/devices/cuda_impl/CudaLayernorm.cc
@@ -36,7 +36,7 @@ LayernormOutput CudaDevice::layernorm(const LayernormParams& params) {
n,
stream_
);
-} else if (params.bias.has_value() || params.residual1.has_value()) {
+} else if (params.bias.has_value() || params.residual1.has_value() || params.residual2.has_value()) {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(data_type, invokeAddBiasResidual,
output.data(),
input.data(),
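In CudaLayernorm.cc the dispatch condition is widened so the fused add-bias-residual path is also taken when only residual2 is supplied. Conceptually the step reduces to an elementwise sum of the input with whichever optional terms are present; the host-side sketch below shows the assumed semantics only, not the CUDA kernel that DISPATCH_CUDA_FUNCTION_DATA_TYPE actually launches.

// Illustrative host-side version of the fused add handled by that branch,
// assuming the optional bias is broadcast per hidden dimension and the
// residual tensors match the input shape. Not the real kernel signature.
#include <cstddef>
#include <optional>
#include <vector>

std::vector<float> addBiasResidual(const std::vector<float>& input,
                                   const std::optional<std::vector<float>>& bias,
                                   const std::optional<std::vector<float>>& residual1,
                                   const std::optional<std::vector<float>>& residual2) {
    std::vector<float> out(input.size());
    for (std::size_t i = 0; i < input.size(); ++i) {
        float v = input[i];
        if (bias)      v += (*bias)[i % bias->size()];  // bias broadcast over tokens
        if (residual1) v += (*residual1)[i];
        if (residual2) v += (*residual2)[i];            // the newly covered case
        out[i] = v;
    }
    return out;
}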
