Merge pull request #106 from Abirdcfly/embedding
feat: add text-embedding in dashscope
bjwswang authored Oct 11, 2023
2 parents ee313c7 + e482d76 commit a679da0
Showing 7 changed files with 298 additions and 15 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -68,11 +68,20 @@ To enhance the AI capability in Golang, we developed some packages. Here are the examples:
- [embedding](https://github.com/kubeagi/arcadia/tree/main/examples/embedding) shows how to embed your documents into a vector store with an embedding service
- [rbac](https://github.com/kubeagi/arcadia/blob/main/examples/rbac/main.go) shows how to inspect the security risks in your RBAC with AI.
- [zhipuai](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go) shows how to use this [zhipuai client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
- [dashscope](https://github.com/kubeagi/arcadia/blob/main/examples/dashscope/main.go) shows how to use this [dashscope client](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/dashscope) to chat with qwen-7b-chat / qwen-14b-chat / llama2-7b-chat-v2 / llama2-13b-chat-v2 and to create embeddings with text-embedding-v1 / text-embedding-async-v1

### LLMs

- [ZhiPuAI (智谱 AI)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/zhipuai)
  - [example](https://github.com/kubeagi/arcadia/blob/main/examples/zhipuai/main.go)
- [DashScope (灵积模型服务)](https://github.com/kubeagi/arcadia/tree/main/pkg/llms/dashscope), which now supports:
  - [qwen-7b-chat (Tongyi Qianwen open-source 7B model)](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-quick-start)
  - [qwen-14b-chat (Tongyi Qianwen open-source 14B model)](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-7b-14b-quick-star)
  - [llama2-7b-chat-v2 (LLaMa2 7B chat model)](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-4)
  - [llama2-13b-chat-v2 (LLaMa2 13B chat model)](https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-4)
  - [text-embedding-v1 (general text embedding, synchronous API)](https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details)
  - [text-embedding-async-v1 (general text embedding, batch API)](https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details)
  - see the [example](https://github.com/kubeagi/arcadia/blob/main/examples/dashscope/main.go)

### Embeddings

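Editorial aside (not part of this commit): below is a minimal sketch of what the new embedding API looks like from caller code, using only identifiers introduced in this diff (`dashscope.NewDashScope`, `CreateEmbedding`, the `Embeddings` type). Reading the API key from a `DASHSCOPE_API_KEY` environment variable is an assumption for illustration; the bundled example takes it from the command line instead.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Illustration only: read the DashScope API key from an environment variable.
	apiKey := os.Getenv("DASHSCOPE_API_KEY")

	// The second argument toggles SSE streaming for chat calls; it has no
	// effect on the embedding endpoints.
	client := dashscope.NewDashScope(apiKey, false)

	// Synchronous embedding (model text-embedding-v1); the final argument
	// selects the "query" text type instead of "document".
	embeddings, err := client.CreateEmbedding(context.TODO(), []string{"hello", "world"}, false)
	if err != nil {
		panic(err)
	}
	for _, e := range embeddings {
		fmt.Printf("text %d -> vector of %d dimensions\n", e.TextIndex, len(e.Embedding))
	}
}
```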
54 changes: 53 additions & 1 deletion examples/dashscope/main.go
@@ -18,17 +18,22 @@ package main

import (
"context"
"fmt"
"os"
"time"

"github.com/kubeagi/arcadia/pkg/llms"
"github.com/kubeagi/arcadia/pkg/llms/dashscope"
"k8s.io/klog/v2"
)

const (
samplePrompt = "how to change a deployment's image?"
samplePrompt = "how to change a deployment's image?"
sampleEmbeddingTextURL = "https://gist.githubusercontent.com/Abirdcfly/e66c1fbd48dbdd89398123362660828b/raw/de30715a99f32b66959f3c4c96b53db82554fa40/demo.txt"
)

var sampleEmbeddingText = []string{"离离原上草", "一岁一枯荣", "野火烧不尽", "春风吹又生"}

func main() {
if len(os.Args) == 1 {
panic("api key is empty")
@@ -57,6 +62,38 @@ func main() {
klog.V(0).Infof("Response: \n %s\n", resp)
}
klog.Infoln("sample chat done")
klog.Infof("\nsample embedding start...\nwe use same embedding: %s to test\n", sampleEmbeddingText)
resp, err := sampleEmbedding(apiKey)
if err != nil {
panic(err)
}
klog.V(0).Infof("embedding sync call return: \n %+v\n", resp)

taskID, err := sampleEmbeddingAsync(apiKey, sampleEmbeddingTextURL)
if err != nil {
panic(err)
}
klog.V(0).Infof("embedding async call will return taskID: %s", taskID)
klog.V(0).Infoln("wait 3s to make the task done")
time.Sleep(3 * time.Second)
downloadURL, err := sampleEmbeddingAsyncGetTaskDetail(apiKey, taskID)
if err != nil {
panic(err)
}
klog.V(0).Infof("get download url: %s\n", downloadURL)
localfile := "/tmp/embedding.txt"
klog.V(0).Infof("download and extract the embedding file to %s...\n", localfile)
err = dashscope.DownloadAndExtract(downloadURL, localfile)
if err != nil {
panic(err)
}
content, err := os.ReadFile(localfile)
if err != nil {
panic(err)
}
klog.V(0).Infoln("show the embedding file content:")
fmt.Println(string(content))
klog.Infoln("sample embedding done")
}

func sampleChat(apiKey string, model dashscope.Model) (llms.Response, error) {
@@ -93,3 +130,18 @@ func sampleSSEChat(apiKey string, model dashscope.Model) error {
}
return nil
}

func sampleEmbedding(apiKey string) ([]dashscope.Embeddings, error) {
client := dashscope.NewDashScope(apiKey, false)
return client.CreateEmbedding(context.TODO(), sampleEmbeddingText, false)
}

func sampleEmbeddingAsync(apiKey string, url string) (string, error) {
client := dashscope.NewDashScope(apiKey, false)
return client.CreateEmbeddingAsync(context.TODO(), url, false)
}

func sampleEmbeddingAsyncGetTaskDetail(apiKey string, taskID string) (string, error) {
client := dashscope.NewDashScope(apiKey, false)
return client.GetTaskDetail(context.TODO(), taskID)
}
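Note on the example above: it sleeps a fixed 3 seconds before fetching the async task result. Since `GetTaskDetail` (added in api.go below) only returns the download URL once the task status is SUCCEEDED and returns an error otherwise, a small polling loop is a more robust alternative. A minimal sketch, not part of the commit:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

// pollTaskDetail retries GetTaskDetail until the async embedding task finishes
// or the deadline passes. An error from GetTaskDetail usually means the task is
// still PENDING or RUNNING; a FAILED task keeps erroring until the deadline.
func pollTaskDetail(apiKey, taskID string, timeout time.Duration) (string, error) {
	client := dashscope.NewDashScope(apiKey, false)
	deadline := time.Now().Add(timeout)
	for {
		downloadURL, err := client.GetTaskDetail(context.TODO(), taskID)
		if err == nil {
			return downloadURL, nil
		}
		if time.Now().After(deadline) {
			return "", fmt.Errorf("task %s not finished before deadline: %w", taskID, err)
		}
		time.Sleep(2 * time.Second)
	}
}

func main() {
	// Placeholder values for illustration; in the example they come from
	// os.Args and CreateEmbeddingAsync respectively.
	downloadURL, err := pollTaskDetail("your-api-key", "your-task-id", 30*time.Second)
	if err != nil {
		panic(err)
	}
	fmt.Println("download url:", downloadURL)
}
```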
110 changes: 106 additions & 4 deletions pkg/llms/dashscope/api.go
@@ -18,13 +18,17 @@ package dashscope

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/kubeagi/arcadia/pkg/llms"
)

const (
DashScopeChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
DashScopeChatURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
DashScopeTextEmbeddingURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
DashScopeTaskURL = "https://dashscope.aliyuncs.com/api/v1/tasks/"
)

type Model string
@@ -34,8 +38,10 @@ const (
QWEN14BChat Model = "qwen-14b-chat"
QWEN7BChat Model = "qwen-7b-chat"
// The LLaMa2 family of large language models was developed and publicly released by Meta, ranging from 7 billion to 70 billion parameters. llama2-7b-chat-v2 and llama2-13b-chat-v2 offered on DashScope are the 7B and 13B LLaMa2 models, fine-tuned for dialogue use cases.
LLAMA27BCHATV2 Model = "llama2-7b-chat-v2"
LLAMA213BCHATV2 Model = "llama2-13b-chat-v2"
LLAMA27BCHATV2 Model = "llama2-7b-chat-v2"
LLAMA213BCHATV2 Model = "llama2-13b-chat-v2"
EmbeddingV1 Model = "text-embedding-v1" // general text embedding, synchronous call
EmbeddingAsyncV1 Model = "text-embedding-async-v1" // general text embedding, batch (async) call
)

var _ llms.LLM = (*DashScope)(nil)
@@ -62,9 +68,105 @@ func (z *DashScope) Call(data []byte) (llms.Response, error) {
if err := params.Unmarshal(data); err != nil {
return nil, err
}
return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse)
return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse, false)
}

func (z *DashScope) Validate() (llms.Response, error) {
return nil, errors.New("not implemented")
}

func (z *DashScope) CreateEmbedding(ctx context.Context, inputTexts []string, query bool) ([]Embeddings, error) {
textType := TextTypeDocument
if query {
textType = TextTypeQuery
}
reqBody := EmbeddingRequest{
Model: EmbeddingV1,
Input: EmbeddingInput{
EmbeddingInputSync: &EmbeddingInputSync{
Texts: inputTexts,
},
},
Parameters: EmbeddingParameters{
TextType: textType,
},
}
data, err := json.Marshal(reqBody)
if err != nil {
return nil, err
}
resp, err := req(ctx, DashScopeTextEmbeddingURL, z.apiKey, data, false, false)
if err != nil {
return nil, err
}
defer resp.Body.Close()
respData := &EmbeddingResponse{}
if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
return nil, err
}
if respData.StatusCode != 200 && respData.StatusCode != 0 {
return nil, errors.New(respData.Message)
}
return respData.Output.Embeddings, nil
}

func (z *DashScope) CreateEmbeddingAsync(ctx context.Context, inputURL string, query bool) (taskID string, err error) {
textType := TextTypeDocument
if query {
textType = TextTypeQuery
}
reqBody := EmbeddingRequest{
Model: EmbeddingAsyncV1,
Input: EmbeddingInput{
EmbeddingInputAsync: &EmbeddingInputAsync{
URL: inputURL,
},
},
Parameters: EmbeddingParameters{
TextType: textType,
},
}
data, err := json.Marshal(reqBody)
if err != nil {
return "", err
}
resp, err := req(ctx, DashScopeTextEmbeddingURL, z.apiKey, data, false, true)
if err != nil {
return "", err
}
defer resp.Body.Close()
respData := &EmbeddingResponse{}
if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
return "", err
}
if respData.StatusCode != 200 && respData.StatusCode != 0 {
return "", errors.New(respData.Message)
}
return respData.Output.TaskID, nil
}

func (z *DashScope) GetTaskDetail(ctx context.Context, taskID string) (outURL string, err error) {
resp, err := req(ctx, DashScopeTaskURL+taskID, z.apiKey, nil, false, false)
if err != nil {
return "", err
}
defer resp.Body.Close()
respData := &EmbeddingResponse{}
if err := json.NewDecoder(resp.Body).Decode(respData); err != nil {
return "", err
}
if respData.StatusCode != 200 && respData.StatusCode != 0 {
return "", errors.New(respData.Message)
}
data := respData.Output.EmbeddingOutputASync
if data == nil {
return "", fmt.Errorf("cant find data in resp:%+v", respData)
}
if data.TaskStatus != TaskStatusSucceeded {
return "", fmt.Errorf("taskStatus:%s, message:%s", data.TaskStatus, data.Message)
}
if data.URL != "" {
return data.URL, nil
}
return "", errors.New(respData.Message)
}
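For reference (an editorial sketch, not part of the commit): `CreateEmbedding` and `CreateEmbeddingAsync` above build the same `EmbeddingRequest`, differing only in the `input` field — inline `texts` for the synchronous model versus a `url` for the batch model — which the embedded-pointer layout of `EmbeddingInput` in embedding.go (next file) makes explicit. Marshaling the two variants shows the payloads; the URL is a placeholder.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/kubeagi/arcadia/pkg/llms/dashscope"
)

func main() {
	// Synchronous request: the texts are sent inline (text-embedding-v1).
	syncReq := dashscope.EmbeddingRequest{
		Model: dashscope.EmbeddingV1,
		Input: dashscope.EmbeddingInput{
			EmbeddingInputSync: &dashscope.EmbeddingInputSync{Texts: []string{"hello"}},
		},
		Parameters: dashscope.EmbeddingParameters{TextType: dashscope.TextTypeDocument},
	}

	// Asynchronous request: a URL to a text file is submitted instead
	// (text-embedding-async-v1, sent with the X-DashScope-Async header).
	asyncReq := dashscope.EmbeddingRequest{
		Model: dashscope.EmbeddingAsyncV1,
		Input: dashscope.EmbeddingInput{
			EmbeddingInputAsync: &dashscope.EmbeddingInputAsync{URL: "https://example.com/demo.txt"},
		},
		Parameters: dashscope.EmbeddingParameters{TextType: dashscope.TextTypeDocument},
	}

	for _, r := range []dashscope.EmbeddingRequest{syncReq, asyncReq} {
		b, _ := json.Marshal(r)
		fmt.Println(string(b))
	}
}
```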
109 changes: 109 additions & 0 deletions pkg/llms/dashscope/embedding.go
@@ -0,0 +1,109 @@
package dashscope

import (
"compress/gzip"
"io"
"net/http"
"os"
"path/filepath"
)

type EmbeddingRequest struct {
Model Model `json:"model"`
Input EmbeddingInput `json:"input"`
Parameters EmbeddingParameters `json:"parameters"`
}

type EmbeddingInput struct {
*EmbeddingInputSync
*EmbeddingInputAsync
}
type EmbeddingInputSync struct {
Texts []string `json:"texts,omitempty"`
}
type EmbeddingInputAsync struct {
URL string `json:"url,omitempty"`
}
type EmbeddingParameters struct {
TextType TextType `json:"text_type"`
}

type TextType string

const (
TextTypeQuery TextType = "query"
TextTypeDocument TextType = "document"
)

type EmbeddingResponse struct {
CommonResponse
Output EmbeddingOutput `json:"output"`
Usage EmbeddingUsage `json:"usage"`
}

type EmbeddingOutput struct {
*EmbeddingOutputSync
*EmbeddingOutputASync
}
type EmbeddingOutputSync struct {
Embeddings []Embeddings `json:"embeddings"`
}

type EmbeddingOutputASync struct {
TaskID string `json:"task_id"`
TaskStatus TaskStatus `json:"task_status"`
URL string `json:"url"`
SubmitTime string `json:"submit_time,omitempty"`
ScheduledTime string `json:"scheduled_time,omitempty"`
EndTime string `json:"end_time,omitempty"`
// when failed
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
type Embeddings struct {
TextIndex int `json:"text_index"`
Embedding []float64 `json:"embedding"`
}

type EmbeddingUsage struct {
TotalTokens int `json:"total_tokens"`
}

type TaskStatus string

const (
TaskStatusPending TaskStatus = "PENDING"
TaskStatusRunning TaskStatus = "RUNNING"
TaskStatusSucceeded TaskStatus = "SUCCEEDED"
TaskStatusFailed TaskStatus = "FAILED"
TaskStatusUnknown TaskStatus = "UNKNOWN"
)

func DownloadAndExtract(url string, dest string) error {
if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
return err
}
destFile, err := os.Create(dest)
if err != nil {
return err
}
defer destFile.Close()

resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()

gzReader, err := gzip.NewReader(resp.Body)
if err != nil {
return err
}
defer gzReader.Close()

_, err = io.Copy(destFile, gzReader)
if err != nil {
return err
}
return nil
}
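`DownloadAndExtract` above writes the decompressed result to a local file and does not check the HTTP status of the download. A hedged variant (assuming, as the function above does, that the result URL serves a gzip-compressed body) can stream the decompressed content to any `io.Writer` and reject non-200 responses:

```go
package main

import (
	"compress/gzip"
	"fmt"
	"io"
	"net/http"
	"os"
)

// streamResult decompresses the gzip body behind url and copies it to w,
// without going through a temporary file.
func streamResult(url string, w io.Writer) error {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected status %s while downloading result", resp.Status)
	}
	gz, err := gzip.NewReader(resp.Body)
	if err != nil {
		return err
	}
	defer gz.Close()
	_, err = io.Copy(w, gz)
	return err
}

func main() {
	// Placeholder URL; in practice this is the download URL returned by GetTaskDetail.
	if err := streamResult("https://example.com/result.txt.gz", os.Stdout); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}
```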
20 changes: 14 additions & 6 deletions pkg/llms/dashscope/http_client.go
@@ -23,7 +23,7 @@ import (
"net/http"
)

func setHeaders(req *http.Request, token string, sse bool) {
func setHeaders(req *http.Request, token string, sse, async bool) {
if sse {
// req.Header.Set("Content-Type", "text/event-stream") // Although the documentation says we should do this, but will return a 400 error and the python sdk doesn't do this.
req.Header.Set("Content-Type", "application/json")
@@ -33,6 +33,9 @@ func setHeaders(req *http.Request, token string, sse, async bool) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "*/*")
}
if async {
req.Header.Set("X-DashScope-Async", "enable")
}
req.Header.Set("Authorization", "Bearer "+token)
}

@@ -43,18 +46,23 @@ func parseHTTPResponse(resp *http.Response) (data *Response, err error) {
return data, nil
}

func req(ctx context.Context, apiURL, token string, data []byte, sse bool) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(data))
func req(ctx context.Context, apiURL, token string, data []byte, sse, async bool) (resp *http.Response, err error) {
var req *http.Request
if len(data) == 0 {
req, err = http.NewRequestWithContext(ctx, "GET", apiURL, nil)
} else {
req, err = http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(data))
}
if err != nil {
return nil, err
}

setHeaders(req, token, sse)
setHeaders(req, token, sse, async)

return http.DefaultClient.Do(req)
}
func do(ctx context.Context, apiURL, token string, data []byte, sse bool) (*Response, error) {
resp, err := req(ctx, apiURL, token, data, sse)
func do(ctx context.Context, apiURL, token string, data []byte, sse, async bool) (*Response, error) {
resp, err := req(ctx, apiURL, token, data, sse, async)
if err != nil {
return nil, err
}
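A quick way to verify the new transport behavior — a GET request when no body is passed, and the `X-DashScope-Async: enable` header when `async` is true — is an in-package test against `net/http/httptest`. This is an editorial sketch, not included in the commit:

```go
package dashscope

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestReqMethodAndAsyncHeader(t *testing.T) {
	type captured struct{ method, async string }
	got := make(chan captured, 1)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got <- captured{r.Method, r.Header.Get("X-DashScope-Async")}
	}))
	defer srv.Close()

	// Empty body -> GET (the GetTaskDetail path); async=true -> async header set.
	resp, err := req(context.TODO(), srv.URL, "dummy-token", nil, false, true)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()

	c := <-got
	if c.method != http.MethodGet || c.async != "enable" {
		t.Fatalf("got method=%q async=%q, want GET and \"enable\"", c.method, c.async)
	}
}
```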