Merge pull request #117 from Abirdcfly/more
feat: add baichuan-7b chatglm-6b in dashscope as llm
bjwswang authored Oct 15, 2023
2 parents 3957517 + a04cf90 commit eaba1a3
Showing 6 changed files with 83 additions and 11 deletions.
11 changes: 7 additions & 4 deletions examples/dashscope/main.go
@@ -53,9 +53,9 @@ func main() {
 			panic(err)
 		}
 	}
-	for _, model := range []dashscope.Model{dashscope.LLAMA27BCHATV2, dashscope.LLAMA213BCHATV2} {
+	for _, model := range []dashscope.Model{dashscope.LLAMA27BCHATV2, dashscope.LLAMA213BCHATV2, dashscope.BAICHUAN7BV1, dashscope.CHATGLM6BV2} {
 		klog.V(0).Infof("\nChat with %s\n", model)
-		resp, err := sampleChatWithLlama2(apiKey, model)
+		resp, err := sampleChatWithOthers(apiKey, model)
 		if err != nil {
 			panic(err)
 		}
@@ -107,9 +107,12 @@ func sampleChat(apiKey string, model dashscope.Model) (llms.Response, error) {
 	return client.Call(params.Marshal())
 }
 
-func sampleChatWithLlama2(apiKey string, model dashscope.Model) (llms.Response, error) {
+func sampleChatWithOthers(apiKey string, model dashscope.Model) (llms.Response, error) {
 	client := dashscope.NewDashScope(apiKey, false)
-	params := dashscope.DefaultModelParams()
+	params := dashscope.DefaultModelParamsSimpleChat()
+	if model == dashscope.CHATGLM6BV2 {
+		params.Input.History = &[]string{}
+	}
 	params.Model = model
 	params.Input.Prompt = samplePrompt
 	return client.Call(params.Marshal())
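For orientation, a minimal standalone sketch of the call path this example now exercises, assuming the package layout shown in this diff; the DASHSCOPE_API_KEY environment variable is a hypothetical stand-in for however the example actually obtains its key:

package main

import (
	"os"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
	"k8s.io/klog/v2"
)

func main() {
	// Hypothetical key source; the real example wires this up differently.
	client := dashscope.NewDashScope(os.Getenv("DASHSCOPE_API_KEY"), false)

	params := dashscope.DefaultModelParamsSimpleChat()
	params.Model = dashscope.CHATGLM6BV2
	// Per the diff above, chatglm-6b-v2 is the one model that gets a
	// history field, even when empty.
	params.Input.History = &[]string{}
	params.Input.Prompt = "Hello!"

	resp, err := client.Call(params.Marshal())
	if err != nil {
		panic(err)
	}
	klog.V(0).Infof("%s", resp.String())
}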
4 changes: 3 additions & 1 deletion pkg/llms/dashscope/api.go
@@ -40,6 +40,8 @@ const (
 	// The LLaMa2 family of large language models was developed and openly released by Meta, with sizes ranging from 7 billion to 70 billion parameters. llama2-7b-chat-v2 and llama2-13b-chat-v2 on DashScope are the 7B and 13B LLaMa2 models, fine-tuned for dialogue scenarios.
 	LLAMA27BCHATV2  Model = "llama2-7b-chat-v2"
 	LLAMA213BCHATV2 Model = "llama2-13b-chat-v2"
+	BAICHUAN7BV1 Model = "baichuan-7b-v1" // baichuan-7B is an open-source large-scale pretrained model developed by Baichuan Intelligence: a 7-billion-parameter Transformer trained on roughly 1.2 trillion tokens, supporting both Chinese and English, with a 4096-token context window. It achieves the best results among models of its size on the standard Chinese and English benchmarks (C-EVAL/MMLU).
+	CHATGLM6BV2  Model = "chatglm-6b-v2"  // ChatGLM2 is a large language model produced by Zhipu AI; its model name on the DashScope platform is "chatglm-6b-v2".
 	EmbeddingV1      Model = "text-embedding-v1"       // general text embedding, synchronous call
 	EmbeddingAsyncV1 Model = "text-embedding-async-v1" // general text embedding, batch call
 )
@@ -68,7 +70,7 @@ func (z *DashScope) Call(data []byte) (llms.Response, error) {
 	if err := params.Unmarshal(data); err != nil {
 		return nil, err
 	}
-	return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse, false)
+	return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse, false, params.Model)
 }
 
 func (z *DashScope) Validate() (llms.Response, error) {
16 changes: 12 additions & 4 deletions pkg/llms/dashscope/http_client.go
@@ -21,6 +21,8 @@ import (
 	"context"
 	"encoding/json"
 	"net/http"
+
+	"github.com/kubeagi/arcadia/pkg/llms"
 )
 
 func setHeaders(req *http.Request, token string, sse, async bool) {
@@ -39,8 +41,8 @@ func setHeaders(req *http.Request, token string, sse, async bool) {
 	req.Header.Set("Authorization", "Bearer "+token)
 }
 
-func parseHTTPResponse(resp *http.Response) (data *Response, err error) {
-	if err = json.NewDecoder(resp.Body).Decode(&data); err != nil {
+func parseHTTPResponse(resp *http.Response, data llms.Response) (llms.Response, error) {
+	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
 		return nil, err
 	}
 	return data, nil
@@ -61,11 +63,17 @@ func req(ctx context.Context, apiURL, token string, data []byte, sse, async bool
 
 	return http.DefaultClient.Do(req)
 }
-func do(ctx context.Context, apiURL, token string, data []byte, sse, async bool) (*Response, error) {
+func do(ctx context.Context, apiURL, token string, data []byte, sse, async bool, model Model) (llms.Response, error) {
	resp, err := req(ctx, apiURL, token, data, sse, async)
 	if err != nil {
 		return nil, err
 	}
 	defer resp.Body.Close()
-	return parseHTTPResponse(resp)
+	var respData llms.Response
+	if model == CHATGLM6BV2 {
+		respData = &ResponseChatGLB6B{}
+	} else {
+		respData = &Response{}
+	}
+	return parseHTTPResponse(resp, respData)
 }
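The extra model argument exists because, judging from the struct tags in response.go later in this diff, chatglm-6b-v2 returns output.text as a JSON object rather than a plain string, so the body must be decoded into a different type. A sketch of the two shapes; the payloads are illustrative, inferred from the struct tags, not captured from the live API:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Illustrative payloads, inferred from the struct tags in response.go.
	generic := []byte(`{"output":{"text":"Hi there"}}`)
	chatglm := []byte(`{"output":{"text":{"response":"Hi there"},"history":["Hello!","Hi there"]}}`)

	r1 := &dashscope.Response{}
	if err := json.Unmarshal(generic, r1); err != nil {
		panic(err)
	}
	fmt.Println(r1.String()) // Hi there

	r2 := &dashscope.ResponseChatGLB6B{}
	if err := json.Unmarshal(chatglm, r2); err != nil {
		panic(err)
	}
	fmt.Println(r2.String()) // Hi there
}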
11 changes: 10 additions & 1 deletion pkg/llms/dashscope/params.go
@@ -42,14 +42,15 @@ var _ llms.ModelParams = (*ModelParams)(nil)
 type ModelParams struct {
 	Model      Model      `json:"model"`
 	Input      Input      `json:"input"`
-	Parameters Parameters `json:"parameters"`
+	Parameters Parameters `json:"parameters,omitempty"`
 }
 
 // +kubebuilder:object:generate=true
 
 type Input struct {
 	Messages []Message `json:"messages,omitempty"`
 	Prompt   string    `json:"prompt,omitempty"`
+	History  *[]string `json:"history,omitempty"`
 }
 
 type Parameters struct {
@@ -80,6 +81,14 @@ func DefaultModelParams() ModelParams {
 		},
 	}
 }
+func DefaultModelParamsSimpleChat() ModelParams {
+	return ModelParams{
+		Model: QWEN14BChat,
+		Input: Input{
+			Prompt: "",
+		},
+	}
+}
 
 func (params *ModelParams) Marshal() []byte {
 	data, err := json.Marshal(params)
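One caveat worth noting: encoding/json ignores omitempty on a non-pointer struct field, so the new `json:"parameters,omitempty"` tag does not actually drop the parameters object from the payload; that would require a *Parameters field. A quick sketch of building and marshaling the new simple-chat params:

package main

import (
	"fmt"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	params := dashscope.DefaultModelParamsSimpleChat()
	params.Model = dashscope.CHATGLM6BV2
	params.Input.Prompt = "Hello!"
	params.Input.History = &[]string{} // marshals as "history":[]

	// "parameters" still appears with its zero values, because omitempty
	// has no effect on a non-pointer struct field.
	fmt.Println(string(params.Marshal()))
}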
43 changes: 42 additions & 1 deletion pkg/llms/dashscope/response.go
@@ -17,12 +17,14 @@ limitations under the License.
 package dashscope
 
 import (
+	"bytes"
 	"encoding/json"
 
 	"github.com/kubeagi/arcadia/pkg/llms"
 )
 
 var _ llms.Response = (*Response)(nil)
+var _ llms.Response = (*ResponseChatGLB6B)(nil)
 
 type CommonResponse struct {
 	// https://help.aliyun.com/zh/dashscope/response-status-codes
@@ -40,6 +42,7 @@ type Response struct {
 type Output struct {
 	Choices []Choice `json:"choices,omitempty"`
 	Text    string   `json:"text,omitempty"`
+	History []string `json:"history,omitempty"`
 }
 
 type FinishReason string
@@ -77,5 +80,43 @@ func (response *Response) Bytes() []byte {
 }
 
 func (response *Response) String() string {
-	return string(response.Bytes())
+	if response.Output.Text != "" {
+		return response.Output.Text
+	}
+	buf := &bytes.Buffer{}
+	for _, c := range response.Output.Choices {
+		buf.WriteString(c.Message.Content)
+	}
+	return buf.String()
+}
+
+type ResponseChatGLB6B struct {
+	CommonResponse
+	Output struct {
+		Text struct {
+			Response string `json:"response,omitempty"`
+		} `json:"text,omitempty"`
+		History []string `json:"history,omitempty"`
+	} `json:"output"`
+	Usage Usage `json:"usage"`
+}
+
+func (r *ResponseChatGLB6B) Type() llms.LLMType {
+	return llms.DashScope
+}
+
+func (r *ResponseChatGLB6B) String() string {
+	return r.Output.Text.Response
+}
+
+func (r *ResponseChatGLB6B) Bytes() []byte {
+	bytes, err := json.Marshal(r)
+	if err != nil {
+		return []byte{}
+	}
+	return bytes
+}
+
+func (r *ResponseChatGLB6B) Unmarshal(bytes []byte) error {
+	return json.Unmarshal(bytes, r)
+}
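The paired History fields (Input.History on requests, Output.History on responses) suggest multi-turn chat with chatglm-6b-v2: carry the history returned by one call into the next request. A speculative sketch of that round trip, not part of this commit's example code:

package main

import (
	"os"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Hypothetical key source, as in the earlier sketch.
	client := dashscope.NewDashScope(os.Getenv("DASHSCOPE_API_KEY"), false)

	params := dashscope.DefaultModelParamsSimpleChat()
	params.Model = dashscope.CHATGLM6BV2
	params.Input.History = &[]string{}
	params.Input.Prompt = "Hello!"

	resp, err := client.Call(params.Marshal())
	if err != nil {
		panic(err)
	}
	// do() decodes chatglm-6b-v2 bodies into *ResponseChatGLB6B,
	// so this type assertion should hold for this model.
	if glm, ok := resp.(*dashscope.ResponseChatGLB6B); ok {
		params.Input.History = &glm.Output.History
		params.Input.Prompt = "And a follow-up question."
		if _, err := client.Call(params.Marshal()); err != nil {
			panic(err)
		}
	}
}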
9 changes: 9 additions & 0 deletions pkg/llms/dashscope/zz_generated.deepcopy.go

Some generated files are not rendered by default.