Merge pull request #266 from siemonchan/qwen
PEFT support
ztxz16 authored Aug 18, 2023
2 parents c6db2e0 + 8a76427 commit feffae3
Showing 11 changed files with 243 additions and 8 deletions.
33 changes: 33 additions & 0 deletions README.md
@@ -75,6 +75,39 @@ new_model = llm.model("model.flm"); # load the fastllm model

Note: this feature is in testing; so far only ChatGLM and ChatGLM2 have been verified to work with the two-line acceleration.

## PEFT support (in testing; currently only ChatGLM + LoRA)

With [🤗PEFT](https://huggingface.co/docs/peft/index) you can easily run fine-tuned large models. Accelerate your PEFT model with fastllm as follows:

```python
import sys
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
sys.path.append('..')
model = AutoModel.from_pretrained("THUDM/chatglm-6b", device_map='cpu', trust_remote_code=True)
model = PeftModel.from_pretrained(model, "path/to/your/own/adapter") # use your own PEFT adapter here
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# if the model has an active_adapter, that adapter will also be enabled by default in the fastllm model
from fastllm_pytools import llm
model = llm.from_hf(model, tokenizer, dtype = "float16") # dtype supports "float16", "int8", "int4"
```

After that, you can use the model just like a regular one (e.g., call the chat and stream_chat functions).
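
A minimal usage sketch (the exact `chat` signature mirrors the ChatGLM API and may differ; prompt and history handling here are illustrative):

```python
# assumes `model` and `tokenizer` from the snippet above
response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
```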

You can also switch the adapter used by the PEFT model:

```python
model.set_adapter('your adapter name')
```

Or disable PEFT and fall back to the original pretrained model:

```python
model.disable_adapter()
```

## Inference speed

A 6B-class int4 model reaches a minimum latency of about 5.5 ms on a single 4090.
10 changes: 10 additions & 0 deletions include/fastllm.h
@@ -384,6 +384,8 @@ namespace fastllm {

std::map <std::string, Data> weight;

std::map <std::string, std::map <std::string, std::string>> peftDict;

std::set <std::string> embeddingNames;

void LoadFromFile(const std::string &fileName); // load from file
@@ -394,6 +396,8 @@

void AddDict(const std::string &key, const std::string &value); // insert a dict entry

void AddAdapterDict(const std::string &name, const std::string &key, const std::string &value);

void AddWeight(const std::string &key, const std::vector <int> &dims,
DataType dataType, WeightType weightType, DataType oriDataType, uint8_t *oriData); // insert a weight

@@ -479,6 +483,12 @@
void SoftmaxBatch(std::vector <Data*> &input, std::vector <Data*> &output, int axis);

void CatDirectBatch(std::vector <Data*> &input0, std::vector <Data*> &input1, int axis);

void LoraLayer(Data &input, Data &weight, Data &loraA, Data &loraB, const Data &bias, Data &output,
std::map <std::string, std::string> loraConfig);

void IA3Layer(Data &input, Data &weight, Data &ia3_l, Data &bias, Data &output,
std::map <std::string, std::string> ia3Config);
}

#endif //TEST_FASTLLM_H
6 changes: 6 additions & 0 deletions include/models/basellm.h
@@ -119,6 +119,10 @@ namespace fastllm {

virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) = 0; // update history with the current reply

virtual void SetAdapter(const std::string &name);

virtual void DisableAdapter();

std::string model_type;

std::string pre_prompt; // prompt for the start of the conversation
@@ -146,5 +150,7 @@
std::mutex mainLoopLocker, dictLocker;

std::map <std::string, int> deviceMap;

std::string adapterName;
};
}
10 changes: 8 additions & 2 deletions src/devices/cpu/cpudevice.cpp
@@ -2104,8 +2104,14 @@
float *input1Data = (float*)input1.cpuData;

int len = input0.Count(0);
int inner = input1.Count(0);
AssertInFastLLM(len % inner == 0, "MulTo error: Data`s shape can`t perform MulTo operation.\n");
int round = (len / inner);
for (int j = 0; j < round; j++) {
for (int i = 0; i < inner; i++) {
input0Data[i] *= input1Data[i];
}
input0Data += inner;
}
}

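The generalized `MulTo` tiles `input1` over `input0` in contiguous chunks, which is what `IA3Layer` (added below in `src/fastllm.cpp`) relies on. A rough NumPy sketch of the semantics, assuming contiguous float data:

```python
import numpy as np

def mul_to(input0: np.ndarray, input1: np.ndarray) -> None:
    # in-place input0 *= input1, with input1 repeated over input0 chunk by chunk
    flat0, flat1 = input0.reshape(-1), input1.reshape(-1)
    assert flat0.size % flat1.size == 0, "shapes can't perform MulTo"
    flat0 *= np.tile(flat1, flat0.size // flat1.size)
```
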
86 changes: 86 additions & 0 deletions src/fastllm.cpp
@@ -1157,6 +1157,22 @@
}
}

if (this->dicts.find("peft_siz") != this->dicts.end()) {
int peftSize = atoi(this->dicts["peft_size"].c_str());
for (int i = 0; i < peftSize; i++) {
std::string adapter_name = buffer.ReadString();
this->peftDict[adapter_name] = {};

int adapter_size = buffer.ReadInt();
for (int j = 0; j < adapter_size; j++) {
std::string key = buffer.ReadString();
std::string value = buffer.ReadString();
//printf("%s %s\n", key.c_str(), value.c_str());
this->peftDict[adapter_name][key] = value;
}
}
}

bool useScore = this->dicts["tokenizer_use_score"] == "1";
int vocabLen = buffer.ReadInt();
for (int i = 0; i < vocabLen; i++) {
@@ -1441,6 +1457,10 @@
this->dicts[key] = value;
}

void WeightMap::AddAdapterDict(const std::string &name, const std::string &key, const std::string &value) {
this->peftDict[name][key] = value;
}

void WeightMap::AddQLinearWeight(const std::string &key, const std::vector <int> &dims,
int bit, float *scales, uint8_t *oriData) {
AssertInFastLLM(bit == 4 || bit == 8, "Error: only support 8 bit or 4 bit QLinear.\n");
@@ -1783,6 +1803,72 @@
}, {}, {{"axis", axis}, {"input0___batch", (int)input0.size()}, {"input1___batch", (int)input1.size()}});
}

void LoraLayer(Data &input, Data &weight, Data &loraA, Data &loraB, const Data &bias, Data &output,
std::map <std::string, std::string> loraConfig) {
float r = std::atof(loraConfig["r"].c_str());
float lora_alpha = std::atof(loraConfig["lora_alpha"].c_str());
bool fan_in_fan_out = loraConfig["fan_in_fan_out"] == "true";
if (r > 0) {
float scaling = lora_alpha / r;
if (fan_in_fan_out) {
Data weightTrans;
Data result, loraAOut, loraBOut;
Permute(weight, {1, 0}, weightTrans);
Linear(input, weightTrans, bias, result);
Linear(input, loraA, Data(), loraAOut);
Linear(loraAOut, loraB, Data(), loraBOut);
Mul(loraBOut, scaling, output);
AddTo(output, result);
} else {
Data result, loraAOut, loraBOut;
Linear(input, weight, bias, result);
Linear(input, loraA, Data(), loraAOut);
Linear(loraAOut, loraB, Data(), loraBOut);
Mul(loraBOut, scaling, output);
AddTo(output, result);
}
} else {
if (fan_in_fan_out) {
Data weightTrans;
Permute(weight, {1, 0}, weightTrans);
Linear(input, weightTrans, bias, output);
} else {
Linear(input, weight, bias, output);
}
}
}
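
In effect, `LoraLayer` computes `output = input @ W^T + bias + (lora_alpha / r) * (input @ A^T) @ B^T`, falling back to a plain `Linear` when `r` is 0. A NumPy sketch under the same storage convention (`Linear` treats `weight` as out_features x in_features):

```python
import numpy as np

def lora_layer(x, W, A, B, bias, r, lora_alpha, fan_in_fan_out=False):
    # sketch of the forward pass implemented by LoraLayer above
    if fan_in_fan_out:        # weight stored transposed on disk
        W = W.T
    base = x @ W.T + bias     # plain Linear
    if r <= 0:                # LoRA disabled
        return base
    scaling = lora_alpha / r
    return base + scaling * (x @ A.T) @ B.T   # low-rank update
```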

void IA3Layer(Data &input, Data &weight, Data &ia3_l, Data &bias, Data &output,
std::map <std::string, std::string> ia3Config) {
bool is_feedforward = ia3Config["if_feedforward"] == "true";
bool fan_in_fan_out = ia3Config["fan_in_fan_out"] == "true";
if (is_feedforward) {
// IA3_L shape: (1, in_features)
// output = linear(input * ia3_l)
if (fan_in_fan_out) {
Data weightTrans;
Permute(weight, {1, 0}, weightTrans);
MulTo(input, ia3_l);
Linear(input, weightTrans, bias, output);
} else {
MulTo(input, ia3_l);
Linear(input, weight, bias, output);
}
} else {
// IA3_L shape: (out_features, 1)
// output = linear(input) * ia3_l
if (fan_in_fan_out) {
Data weightTrans;
Permute(weight, {1, 0}, weightTrans);
Linear(input, weightTrans, bias, output);
MulTo(output, ia3_l);
} else {
Linear(input, weight, bias, output);
MulTo(output, ia3_l);
}
}
}
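
`IA3Layer` scales either the layer input (feedforward case, `ia3_l` of shape `(1, in_features)`) or the layer output (`ia3_l` of shape `(out_features, 1)`). The same logic as a NumPy sketch:

```python
import numpy as np

def ia3_layer(x, W, ia3_l, bias, is_feedforward, fan_in_fan_out=False):
    # sketch of the forward pass implemented by IA3Layer above
    if fan_in_fan_out:
        W = W.T
    scale = ia3_l.reshape(-1)
    if is_feedforward:                  # scale the input, then project
        return (x * scale) @ W.T + bias
    return (x @ W.T + bias) * scale     # project, then scale
```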

void ClearProfiler() {
curExecutor->ClearProfiler();
}
11 changes: 11 additions & 0 deletions src/models/basellm.cpp
@@ -526,4 +526,15 @@
const std::vector<std::map<std::string, int>> &params, fastllm::Data &inputIds,
fastllm::Data &attentionMask, fastllm::Data &positionIds) {
}

void basellm::SetAdapter(const std::string &name) {
if (weight.peftDict.find(name) == weight.peftDict.end()) {
ErrorInFastLLM("Can`t find adapter name: " + name);
}
adapterName = name;
}

void basellm::DisableAdapter() {
adapterName = "";
}
}
28 changes: 26 additions & 2 deletions src/models/chatglm.cpp
@@ -130,7 +130,19 @@
}
std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias";
if (!adapterName.empty()) {
std::string peftType = weight.peftDict[adapterName]["peft_type"];
if (peftType == "LORA") {
std::string loraAWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_A." + adapterName + ".weight";
std::string loraBWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_B." + adapterName + ".weight";
LoraLayer(attenInput, weight[qkvWeightName], weight[loraAWeightName], weight[loraBWeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]);
} else if (peftType == "IA3") {
std::string ia3WeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.ia3_l." + adapterName + ".weight";
IA3Layer(attenInput, weight[qkvWeightName], weight[ia3WeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]);
}
} else {
Linear(attenInput, weight[qkvWeightName], weight[qkvBiasName], qkv);
}
if (version == 1) {
qkv.Reshape({qkv.dims[0], qkv.dims[1], num_attention_heads, -1});
int per = qkv.dims.back() / 3;
@@ -394,7 +406,19 @@

std::string qkvWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.weight";
std::string qkvBiasName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.bias";
if (!adapterName.empty()) {
std::string peftType = weight.peftDict[adapterName]["peft_type"];
if (peftType == "LORA") {
std::string loraAWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_A." + adapterName + ".weight";
std::string loraBWeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.lora_B." + adapterName + ".weight";
LoraLayer(attenInput, weight[qkvWeightName], weight[loraAWeightName], weight[loraBWeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]);
} else if (peftType == "IA3") {
std::string ia3WeightName = weightPre + std::to_string(i) + weightMiddle + ".query_key_value.ia3_l." + adapterName + ".weight";
IA3Layer(attenInput, weight[qkvWeightName], weight[ia3WeightName], weight[qkvBiasName], qkv, weight.peftDict[adapterName]);
}
} else {
Linear(attenInput, weight[qkvWeightName], weight[qkvBiasName], qkv);
}

if (version == 1) {
qkv.Reshape({qkv.dims[0], qkv.dims[1], num_attention_heads, -1});
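
The adapter weight lookups above follow PEFT's module naming with the `base_model.model.` prefix stripped on export (see `hf_model.py` and `torch2flm.py` below). For a hypothetical adapter named `default` on layer `i` of ChatGLM-6B, the expected keys look roughly like this (the layer prefix varies between ChatGLM versions):

```python
i, adapter = 0, "default"   # illustrative values
base = f"transformer.layers.{i}.attention.query_key_value"
lora_a = f"{base}.lora_A.{adapter}.weight"   # shape (r, in_features)
lora_b = f"{base}.lora_B.{adapter}.weight"   # shape (out_features, r)
ia3_l = f"{base}.ia3_l.{adapter}.weight"
```
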
21 changes: 19 additions & 2 deletions tools/fastllm_pytools/hf_model.py
@@ -63,13 +63,27 @@ def create(model,
if (isinstance(m, torch.nn.Embedding)):
weight_type_dict[key] = "embedding";

peft_config = {}
active_adapter = ""
if hasattr(model, "peft_config"):
peft_config = model.peft_config
if hasattr(model, "active_adapter"):
active_adapter = model.active_adapter

model = model.cpu();
dict = model.state_dict();
model_type = model.config.__dict__["model_type"];
model = llm.fastllm_lib.create_empty_llm_model(model_type.encode());
for it in modelInfo.keys():
llm.fastllm_lib.add_dict_llm_model(model, str(it).encode(), str(modelInfo[it]).encode());

for adapter_name in peft_config.keys():
adapter_dict = peft_config[adapter_name].__dict__
for it in adapter_dict.keys():
llm.fastllm_lib.add_adapter_dict_llm_model(model, str(adapter_name).encode(), str(it).encode(), str(adapter_dict[it]).encode())
if len(active_adapter) != 0:
llm.fastllm_lib.set_adapter(model, str(active_adapter).encode())

# 1. vocab
if (tokenizer):
if (hasattr(tokenizer, "tokenizer")):
@@ -110,15 +124,18 @@ def create(model,
# TODO bfloat
to_data_type = 0;

weight_name = key
if peft_config:  # only PEFT-wrapped models carry the 'base_model.model.' prefix
weight_name = weight_name.replace('base_model.model.', '')
if (cur_weight_type == 111):
llm.fastllm_lib.add_qlinear_weight_llm_model(model, weight_name.encode(),
len(dict[key].shape),
(ctypes.c_int * len(dict[key].shape))(*list(dict[key].shape)),
weight_bits[key],
dict[key + "_scale"].numpy().astype(np.float32).ctypes.data_as(ctypes.c_void_p),
dict[key].numpy().ctypes.data_as(ctypes.c_void_p));
else:
llm.fastllm_lib.add_weight_llm_model(model, weight_name.encode(),
len(dict[key].shape),
(ctypes.c_int * len(dict[key].shape))(*list(dict[key].shape)),
to_data_type, cur_weight_type, ori_data_type,
6 changes: 6 additions & 0 deletions tools/fastllm_pytools/llm.py
@@ -232,3 +232,9 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No
yield response, new_history, None;
else:
yield response, new_history;

def set_adapter(self, name: str):
fastllm_lib.set_adapter(self.model, str(name).encode())

def disable_adapter(self):
fastllm_lib.disable_adapter(self.model)
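
These wrappers complete the Python-side API. Switching adapters on a loaded model then looks like this (the adapter name is illustrative):

```python
# `model` created via llm.from_hf(...) on a PEFT-wrapped checkpoint
model.set_adapter("my_adapter")   # route query_key_value through the adapter
model.disable_adapter()           # fall back to the plain pretrained weights
```
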
22 changes: 20 additions & 2 deletions tools/fastllm_pytools/torch2flm.py
@@ -101,10 +101,22 @@ def tofile(exportPath,

modelInfo["tokenizer_use_score"] = "1" # 分词带分数

if hasattr(model, "peft_config"):
adapter_size = len(model.peft_config)
modelInfo["peft_size"] = adapter_size

fo.write(struct.pack('i', len(modelInfo)))
for it in modelInfo.keys():
writeKeyValue(fo, str(it), str(modelInfo[it]))

if hasattr(model, "peft_config"):
for adapter_name in model.peft_config.keys():
adapter_dict = model.peft_config[adapter_name].__dict__
writeString(fo, adapter_name)
fo.write(struct.pack('i', len(adapter_dict)))
for it in adapter_dict.keys():
writeKeyValue(fo, str(it), str(adapter_dict[it]))

# 1. vocab
if (tokenizer):
if (hasattr(tokenizer, "tokenizer")):
@@ -166,8 +178,14 @@
ori_np_data_type = np.float16

cur = dict[key].numpy().astype(ori_np_data_type)

if hasattr(model, "peft_config"):
weight_name = key.replace('base_model.model.', '')
fo.write(struct.pack('i', len(weight_name)))
fo.write(weight_name.encode())
else:
fo.write(struct.pack('i', len(key)))
fo.write(key.encode())
fo.write(struct.pack('i', len(cur.shape)))
for i in cur.shape:
fo.write(struct.pack('i', i))
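
The adapter header written above is read back by `WeightMap::LoadFromFile` when `peft_size` is present (see `src/fastllm.cpp` earlier in this diff). A minimal Python sketch of that section's layout, assuming the int-length-prefixed string encoding used by `writeString`/`writeKeyValue`:

```python
import struct

def read_string(f):
    (n,) = struct.unpack('i', f.read(4))
    return f.read(n).decode()

def read_peft_section(f, peft_size):
    # mirrors the reading loop in WeightMap::LoadFromFile
    adapters = {}
    for _ in range(peft_size):
        name = read_string(f)
        (count,) = struct.unpack('i', f.read(4))
        config = {}
        for _ in range(count):
            key = read_string(f)
            config[key] = read_string(f)
        adapters[name] = config
    return adapters
```
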
18 changes: 18 additions & 0 deletions tools/src/pytools.cpp
@@ -123,6 +123,24 @@ extern "C" {
return;
}

DLL_EXPORT void add_adapter_dict_llm_model(int modelId, char *adapterName, char *key, char *value) {
auto model = models.GetModel(modelId);
model->weight.AddAdapterDict(adapterName, key, value);
return;
}

DLL_EXPORT void set_adapter(int modelId, char *name) {
auto model = models.GetModel(modelId);
model->SetAdapter(name);
return;
}

DLL_EXPORT void disable_adapter(int modelId) {
auto model = models.GetModel(modelId);
model->DisableAdapter();
return;
}

DLL_EXPORT void init_params_llm_model(int modelId) {
auto model = models.GetModel(modelId);
model->InitParams();
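
The Python side reaches these exports through ctypes (see `llm.py` above). A sketch of explicit `argtypes` declarations matching the new signatures; the library handle is illustrative, and fastllm's own setup may rely on ctypes defaults instead:

```python
import ctypes

fastllm_lib = ctypes.cdll.LoadLibrary("libfastllm_tools.so")  # name illustrative
fastllm_lib.add_adapter_dict_llm_model.argtypes = [
    ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p]
fastllm_lib.set_adapter.argtypes = [ctypes.c_int, ctypes.c_char_p]
fastllm_lib.disable_adapter.argtypes = [ctypes.c_int]
```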
