diff --git a/README.md b/README.md
index 89d27962c..516cbb689 100644
--- a/README.md
+++ b/README.md
@@ -68,11 +68,20 @@ To enhance the AI capability in Golang, we developed some packages. Here are the examples:
 - [embedding](https://github.com/kubeagi/arcadia/tree/main/examples/embedding) shows how to embed your document into a vector store with an embedding service
 - [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows how to inquire about the security risks in your RBAC with AI.
 - [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) shows how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
+- [dashscope](https://github.com/kubeagi/arcadia/blob/main/examples/dashscope/main.go) shows how to use this [dashscope client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/dashscope) to chat with qwen-7b-chat / qwen-14b-chat / llama2-7b-chat-v2 / llama2-13b-chat-v2 and to create embeddings with dashscope text-embedding-v1 / text-embedding-async-v1
 
 ### LLMs
 
 - ✅ [ZhiPuAI(智谱 AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
   - [example](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go)
+- ✅ [DashScope(灵积模型服务)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/dashscope), we now support
+  - ✅ [qwen-7b-chat(通义千问开源 7B 模型)](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-quick-start)
+  - ✅ [qwen-14b-chat(通义千问开源 14B 模型)](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-quick-start)
+  - ✅ [llama2-7b-chat-v2(LLaMa2 大语言模型 7B)](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-4)
+  - ✅ [llama2-13b-chat-v2(LLaMa2 大语言模型 13B)](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-4)
+  - ✅ [text-embedding-v1(通用文本向量 同步接口)](https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details)
+  - ✅ [text-embedding-async-v1(通用文本向量 批处理接口)](https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details)
+  - see the [example](https://github.com/kubeagi/arcadia/blob/main/examples/dashscope/main.go)
 
 ### Embeddings
 
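The README bullets above map onto the client API added in the rest of this patch. As a quick orientation before the full example program, here is a minimal sketch of the synchronous embedding call; it is not part of the patch, and the `DASHSCOPE_API_KEY` environment variable is an assumption of the sketch (the repository example takes the key from `os.Args[1]`). Everything else uses only functions and types visible in this diff.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Assumed: the API key is read from an environment variable here;
	// examples/dashscope/main.go passes it as os.Args[1] instead.
	client := dashscope.NewDashScope(os.Getenv("DASHSCOPE_API_KEY"), false)

	// Synchronous embedding with text-embedding-v1, as documented in the README above.
	embeddings, err := client.CreateEmbedding(context.TODO(), []string{"hello", "world"}, false)
	if err != nil {
		panic(err)
	}
	for _, e := range embeddings {
		fmt.Printf("text %d -> %d-dimensional vector\n", e.TextIndex, len(e.Embedding))
	}
}
```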
diff --git a/examples/dashscope/main.go b/examples/dashscope/main.go
index 751241a3b..8188af17d 100644
--- a/examples/dashscope/main.go
+++ b/examples/dashscope/main.go
@@ -18,7 +18,9 @@ package main
 
 import (
 	"context"
+	"fmt"
 	"os"
+	"time"
 
 	"github.com/kubeagi/arcadia/pkg/llms"
 	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
@@ -26,9 +28,12 @@ import (
 )
 
 const (
-	samplePrompt = "how to change a deployment's image?"
+	samplePrompt           = "how to change a deployment's image?"
+	sampleEmbeddingTextURL = "https://gist.githubusercontent.com/Abirdcfly/e66c1fbd48dbdd89398123362660828b/raw/de30715a99f32b66959f3c4c96b53db82554fa40/demo.txt"
 )
 
+var sampleEmbeddingText = []string{"离离原上草", "一岁一枯荣", "野火烧不尽", "春风吹又生"}
+
 func main() {
 	if len(os.Args) == 1 {
 		panic("api key is empty")
@@ -57,6 +62,38 @@ func main() {
 		klog.V(0).Infof("Response: \n %s\n", resp)
 	}
 	klog.Infoln("sample chat done")
+	klog.Infof("\nsample embedding start...\nwe use the sample embedding texts: %s to test\n", sampleEmbeddingText)
+	resp, err := sampleEmbedding(apiKey)
+	if err != nil {
+		panic(err)
+	}
+	klog.V(0).Infof("embedding sync call returned: \n %+v\n", resp)
+
+	taskID, err := sampleEmbeddingAsync(apiKey, sampleEmbeddingTextURL)
+	if err != nil {
+		panic(err)
+	}
+	klog.V(0).Infof("embedding async call returned taskID: %s", taskID)
+	klog.V(0).Infoln("wait 3s for the task to complete")
+	time.Sleep(3 * time.Second)
+	downloadURL, err := sampleEmbeddingAsyncGetTaskDetail(apiKey, taskID)
+	if err != nil {
+		panic(err)
+	}
+	klog.V(0).Infof("get download url: %s\n", downloadURL)
+	localfile := "/tmp/embedding.txt"
+	klog.V(0).Infof("download and extract the embedding file to %s...\n", localfile)
+	err = dashscope.DownloadAndExtract(downloadURL, localfile)
+	if err != nil {
+		panic(err)
+	}
+	content, err := os.ReadFile(localfile)
+	if err != nil {
+		panic(err)
+	}
+	klog.V(0).Infoln("show the embedding file content:")
+	fmt.Println(string(content))
+	klog.Infoln("sample embedding done")
 }
 
 func sampleChat(apiKey string, model dashscope.Model) (llms.Response, error) {
@@ -93,3 +130,18 @@ func sampleSSEChat(apiKey string, model dashscope.Model) error {
 	}
 	return nil
 }
+
+func sampleEmbedding(apiKey string) ([]dashscope.Embeddings, error) {
+	client := dashscope.NewDashScope(apiKey, false)
+	return client.CreateEmbedding(context.TODO(), sampleEmbeddingText, false)
+}
+
+func sampleEmbeddingAsync(apiKey string, url string) (string, error) {
+	client := dashscope.NewDashScope(apiKey, false)
+	return client.CreateEmbeddingAsync(context.TODO(), url, false)
+}
+
+func sampleEmbeddingAsyncGetTaskDetail(apiKey string, taskID string) (string, error) {
+	client := dashscope.NewDashScope(apiKey, false)
+	return client.GetTaskDetail(context.TODO(), taskID)
+}
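The example above waits a fixed 3 seconds before fetching the async task result. Since `GetTaskDetail` (added in `pkg/llms/dashscope/api.go` below) returns an error while the task is still PENDING or RUNNING, a small polling loop with a deadline is usually more robust than a fixed sleep. A minimal sketch, not part of the patch; the `DASHSCOPE_API_KEY` environment variable and the input URL are placeholders, while the client calls are the ones introduced by this diff.

```go
package main

import (
	"context"
	"fmt"
	"os"
	"time"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Assumed env var; the example program takes the key from os.Args[1].
	client := dashscope.NewDashScope(os.Getenv("DASHSCOPE_API_KEY"), false)

	// The input must be a publicly reachable text file, like sampleEmbeddingTextURL above.
	taskID, err := client.CreateEmbeddingAsync(context.TODO(), "https://example.com/input.txt", false)
	if err != nil {
		panic(err)
	}

	// GetTaskDetail returns an error until the task reaches SUCCEEDED,
	// so retry on error until a URL comes back or the deadline passes.
	deadline := time.Now().Add(30 * time.Second)
	for {
		downloadURL, err := client.GetTaskDetail(context.TODO(), taskID)
		if err == nil {
			fmt.Println("result available at:", downloadURL)
			return
		}
		if time.Now().After(deadline) {
			panic(fmt.Errorf("task %s not finished before deadline: %w", taskID, err))
		}
		time.Sleep(time.Second)
	}
}
```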
diff --git a/pkg/llms/dashscope/api.go b/pkg/llms/dashscope/api.go
index f84c13a49..6c6ca9e6c 100644
--- a/pkg/llms/dashscope/api.go
+++ b/pkg/llms/dashscope/api.go
@@ -18,13 +18,17 @@ package dashscope
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
+	"fmt"
 
 	"github.com/kubeagi/arcadia/pkg/llms"
 )
 
 const (
-	DashScopeChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+	DashScopeChatURL          = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+	DashScopeTextEmbeddingURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+	DashScopeTaskURL          = "https://dashscope.aliyuncs.com/api/v1/tasks/"
 )
 
 type Model string
@@ -34,8 +38,10 @@ const (
 	QWEN14BChat Model = "qwen-14b-chat"
 	QWEN7BChat  Model = "qwen-7b-chat"
 	// LLaMa2 系列大语言模型由 Meta 开发并公开发布,其规模从 70 亿到 700 亿参数不等。在灵积上提供的 llama2-7b-chat-v2 和 llama2-13b-chat-v2,分别为 7B 和 13B 规模的 LLaMa2 模型,针对对话场景微调优化后的版本。
-	LLAMA27BCHATV2  Model = "llama2-7b-chat-v2"
-	LLAMA213BCHATV2 Model = "llama2-13b-chat-v2"
+	LLAMA27BCHATV2   Model = "llama2-7b-chat-v2"
+	LLAMA213BCHATV2  Model = "llama2-13b-chat-v2"
+	EmbeddingV1      Model = "text-embedding-v1"       // 通用文本向量 同步调用
+	EmbeddingAsyncV1 Model = "text-embedding-async-v1" // 通用文本向量 批处理调用
 )
 
 var _ llms.LLM = (*DashScope)(nil)
@@ -62,9 +68,105 @@ func (z *DashScope) Call(data []byte) (llms.Response, error) {
 	if err := params.Unmarshal(data); err != nil {
 		return nil, err
 	}
-	return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse)
+	return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse, false)
 }
 
 func (z *DashScope) Validate() (llms.Response, error) {
 	return nil, errors.New("not implemented")
 }
+
+func (z *DashScope) CreateEmbedding(ctx context.Context, inputTexts []string, query bool) ([]Embeddings, error) {
+	textType := TextTypeDocument
+	if query {
+		textType = TextTypeQuery
+	}
+	reqBody := EmbeddingRequest{
+		Model: EmbeddingV1,
+		Input: EmbeddingInput{
+			EmbeddingInputSync: &EmbeddingInputSync{
+				Texts: inputTexts,
+			},
+		},
+		Parameters: EmbeddingParameters{
+			TextType: textType,
+		},
+	}
+	data, err := json.Marshal(reqBody)
+	if err != nil {
+		return nil, err
+	}
+	resp, err := req(ctx, DashScopeTextEmbeddingURL, z.apiKey, data, false, false)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	respData := &EmbeddingResponse{}
+	if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
+		return nil, err
+	}
+	if respData.StatusCode != 200 && respData.StatusCode != 0 {
+		return nil, errors.New(respData.Message)
+	}
+	return respData.Output.Embeddings, nil
+}
+
+func (z *DashScope) CreateEmbeddingAsync(ctx context.Context, inputURL string, query bool) (taskID string, err error) {
+	textType := TextTypeDocument
+	if query {
+		textType = TextTypeQuery
+	}
+	reqBody := EmbeddingRequest{
+		Model: EmbeddingAsyncV1,
+		Input: EmbeddingInput{
+			EmbeddingInputAsync: &EmbeddingInputAsync{
+				URL: inputURL,
+			},
+		},
+		Parameters: EmbeddingParameters{
+			TextType: textType,
+		},
+	}
+	data, err := json.Marshal(reqBody)
+	if err != nil {
+		return "", err
+	}
+	resp, err := req(ctx, DashScopeTextEmbeddingURL, z.apiKey, data, false, true)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+	respData := &EmbeddingResponse{}
+	if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
+		return "", err
+	}
+	if respData.StatusCode != 200 && respData.StatusCode != 0 {
+		return "", errors.New(respData.Message)
+	}
+	return respData.Output.TaskID, nil
+}
+
+func (z *DashScope) GetTaskDetail(ctx context.Context, taskID string) (outURL string, err error) {
+	resp, err := req(ctx, DashScopeTaskURL+taskID, z.apiKey, nil, false, false)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+	respData := &EmbeddingResponse{}
+	if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
+		return "", err
+	}
+	if respData.StatusCode != 200 && respData.StatusCode != 0 {
+		return "", errors.New(respData.Message)
+	}
+	data := respData.Output.EmbeddingOutputASync
+	if data == nil {
+		return "", fmt.Errorf("cannot find data in resp: %+v", respData)
+	}
+	if data.TaskStatus != TaskStatusSucceeded {
+		return "", fmt.Errorf("taskStatus:%s, message:%s", data.TaskStatus, data.Message)
+	}
+	if data.URL != "" {
+		return data.URL, nil
+	}
+	return "", errors.New(respData.Message)
+}
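`CreateEmbedding` and `CreateEmbeddingAsync` differ only in which embedded input they populate and in whether the `X-DashScope-Async` header is sent by the HTTP layer. To make the wire shape concrete, here is a sketch (not part of the patch) that marshals both request bodies using the types defined in `pkg/llms/dashscope/embedding.go` below; the texts and URL are placeholders, and the printed JSON is simply what `encoding/json` produces for these structs, not a captured API payload.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Synchronous request: only the embedded EmbeddingInputSync is set, so the
	// nil *EmbeddingInputAsync is skipped by encoding/json and "input" carries "texts".
	sync := dashscope.EmbeddingRequest{
		Model: dashscope.EmbeddingV1,
		Input: dashscope.EmbeddingInput{
			EmbeddingInputSync: &dashscope.EmbeddingInputSync{Texts: []string{"hello", "world"}},
		},
		Parameters: dashscope.EmbeddingParameters{TextType: dashscope.TextTypeDocument},
	}

	// Asynchronous request: "input" carries a URL instead of inline texts; the
	// X-DashScope-Async header is added separately in http_client.go.
	async := dashscope.EmbeddingRequest{
		Model: dashscope.EmbeddingAsyncV1,
		Input: dashscope.EmbeddingInput{
			EmbeddingInputAsync: &dashscope.EmbeddingInputAsync{URL: "https://example.com/input.txt"},
		},
		Parameters: dashscope.EmbeddingParameters{TextType: dashscope.TextTypeDocument},
	}

	for _, r := range []dashscope.EmbeddingRequest{sync, async} {
		b, err := json.MarshalIndent(r, "", "  ")
		if err != nil {
			panic(err)
		}
		fmt.Println(string(b))
	}
}
```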
diff --git a/pkg/llms/dashscope/embedding.go b/pkg/llms/dashscope/embedding.go
new file mode 100644
index 000000000..bf9e6cfbc
--- /dev/null
+++ b/pkg/llms/dashscope/embedding.go
@@ -0,0 +1,109 @@
+package dashscope
+
+import (
+	"compress/gzip"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+)
+
+type EmbeddingRequest struct {
+	Model      Model               `json:"model"`
+	Input      EmbeddingInput      `json:"input"`
+	Parameters EmbeddingParameters `json:"parameters"`
+}
+
+type EmbeddingInput struct {
+	*EmbeddingInputSync
+	*EmbeddingInputAsync
+}
+type EmbeddingInputSync struct {
+	Texts []string `json:"texts,omitempty"`
+}
+type EmbeddingInputAsync struct {
+	URL string `json:"url,omitempty"`
+}
+type EmbeddingParameters struct {
+	TextType TextType `json:"text_type"`
+}
+
+type TextType string
+
+const (
+	TextTypeQuery    TextType = "query"
+	TextTypeDocument TextType = "document"
+)
+
+type EmbeddingResponse struct {
+	CommonResponse
+	Output EmbeddingOutput `json:"output"`
+	Usage  EmbeddingUsage  `json:"usage"`
+}
+
+type EmbeddingOutput struct {
+	*EmbeddingOutputSync
+	*EmbeddingOutputASync
+}
+type EmbeddingOutputSync struct {
+	Embeddings []Embeddings `json:"embeddings"`
+}
+
+type EmbeddingOutputASync struct {
+	TaskID        string     `json:"task_id"`
+	TaskStatus    TaskStatus `json:"task_status"`
+	URL           string     `json:"url"`
+	SubmitTime    string     `json:"submit_time,omitempty"`
+	ScheduledTime string     `json:"scheduled_time,omitempty"`
+	EndTime       string     `json:"end_time,omitempty"`
+	// when failed
+	Code    string `json:"code,omitempty"`
+	Message string `json:"message,omitempty"`
+}
+type Embeddings struct {
+	TextIndex int       `json:"text_index"`
+	Embedding []float64 `json:"embedding"`
+}
+
+type EmbeddingUsage struct {
+	TotalTokens int `json:"total_tokens"`
+}
+
+type TaskStatus string
+
+const (
+	TaskStatusPending   TaskStatus = "PENDING"
+	TaskStatusRunning   TaskStatus = "RUNNING"
+	TaskStatusSucceeded TaskStatus = "SUCCEEDED"
+	TaskStatusFailed    TaskStatus = "FAILED"
+	TaskStatusUnknown   TaskStatus = "UNKNOWN"
+)
+
+func DownloadAndExtract(url string, dest string) error {
+	if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
+		return err
+	}
+	destFile, err := os.Create(dest)
+	if err != nil {
+		return err
+	}
+	defer destFile.Close()
+
+	resp, err := http.Get(url)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	gzReader, err := gzip.NewReader(resp.Body)
+	if err != nil {
+		return err
+	}
+	defer gzReader.Close()
+
+	_, err = io.Copy(destFile, gzReader)
+	if err != nil {
+		return err
+	}
+	return nil
+}
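`DownloadAndExtract` above simply streams a gzip-compressed HTTP response into a local file. The following sketch (not part of the patch) exercises it against a local `httptest` server so nothing depends on a real DashScope task result; the payload is an arbitrary placeholder and does not claim to be the real result-file format.

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Build a gzip-compressed body, standing in for the file behind the task's download URL.
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	if _, err := gw.Write([]byte("line one\nline two\n")); err != nil {
		panic(err)
	}
	if err := gw.Close(); err != nil {
		panic(err)
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = w.Write(buf.Bytes())
	}))
	defer srv.Close()

	dest := "/tmp/embedding-demo.txt"
	if err := dashscope.DownloadAndExtract(srv.URL, dest); err != nil {
		panic(err)
	}
	content, err := os.ReadFile(dest)
	if err != nil {
		panic(err)
	}
	fmt.Print(string(content)) // prints the decompressed placeholder payload
}
```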
req.Header.Set("Content-Type", "application/json") @@ -33,6 +33,9 @@ func setHeaders(req *http.Request, token string, sse bool) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "*/*") } + if async { + req.Header.Set("X-DashScope-Async", "enable") + } req.Header.Set("Authorization", "Bearer "+token) } @@ -43,18 +46,23 @@ func parseHTTPResponse(resp *http.Response) (data *Response, err error) { return data, nil } -func req(ctx context.Context, apiURL, token string, data []byte, sse bool) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(data)) +func req(ctx context.Context, apiURL, token string, data []byte, sse, async bool) (resp *http.Response, err error) { + var req *http.Request + if len(data) == 0 { + req, err = http.NewRequestWithContext(ctx, "GET", apiURL, nil) + } else { + req, err = http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(data)) + } if err != nil { return nil, err } - setHeaders(req, token, sse) + setHeaders(req, token, sse, async) return http.DefaultClient.Do(req) } -func do(ctx context.Context, apiURL, token string, data []byte, sse bool) (*Response, error) { - resp, err := req(ctx, apiURL, token, data, sse) +func do(ctx context.Context, apiURL, token string, data []byte, sse, async bool) (*Response, error) { + resp, err := req(ctx, apiURL, token, data, sse, async) if err != nil { return nil, err } diff --git a/pkg/llms/dashscope/response.go b/pkg/llms/dashscope/response.go index 19d70554e..d643d9c23 100644 --- a/pkg/llms/dashscope/response.go +++ b/pkg/llms/dashscope/response.go @@ -24,15 +24,18 @@ import ( var _ llms.Response = (*Response)(nil) -type Response struct { +type CommonResponse struct { // https://help.aliyun.com/zh/dashscope/response-status-codes StatusCode int `json:"status_code,omitempty"` Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` - Output Output `json:"output"` - Usage Usage `json:"usage"` RequestID string `json:"request_id"` } +type Response struct { + CommonResponse + Output Output `json:"output"` + Usage Usage `json:"usage"` +} type Output struct { Choices []Choice `json:"choices,omitempty"` diff --git a/pkg/llms/dashscope/sse_client.go b/pkg/llms/dashscope/sse_client.go index 881272df1..fbb69bcd5 100644 --- a/pkg/llms/dashscope/sse_client.go +++ b/pkg/llms/dashscope/sse_client.go @@ -49,7 +49,7 @@ func defaultHandler(event *sse.Event, last string) (newData string) { return "" } func (z *DashScope) StreamCall(ctx context.Context, data []byte, handler func(event *sse.Event, last string) (data string)) error { - resp, err := req(ctx, DashScopeChatURL, z.apiKey, data, true) + resp, err := req(ctx, DashScopeChatURL, z.apiKey, data, true, false) if err != nil { return err }