From 933416c241c1af55fc7fc42876c08b32fd9bf27f Mon Sep 17 00:00:00 2001
From: bjwswang
Date: Mon, 23 Oct 2023 14:51:08 +0800
Subject: [PATCH] feat: able to configure text splitter in local dataset
 management

Signed-off-by: bjwswang
---
 arctl/README.md                   | 20 +++----
 arctl/chat.go                     |  3 -
 arctl/dataset.go                  | 97 +++++++++++++++++++------
 pkg/embeddings/zhipuai/zhipuai.go |  2 +-
 4 files changed, 72 insertions(+), 50 deletions(-)

diff --git a/arctl/README.md b/arctl/README.md
index 448c13343..9a341a83c 100644
--- a/arctl/README.md
+++ b/arctl/README.md
@@ -81,7 +81,7 @@ You can use `arctl` to manage your dataset locally with the following commands:
 
 ```shell
-❯ ./bin/arctl dataset create -h
+❯ arctl dataset create -h
 Create dataset
 
 Usage:
@@ -96,6 +96,7 @@ Flags:
       --llm-apikey string        apiKey to access embedding service
       --llm-type string          llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
       --name string              dataset(namespace/collection) of the document to load into
+      --text-splitter string     text splitter to use(Only character,token,markdown supported now) (default "character")
       --vector-store string      vector stores to use(Only chroma supported now) (default "http://127.0.0.1:8000")
 
 Global Flags:
@@ -145,7 +146,7 @@ Global Flags:
 For example:
 
 ```shell
-❯ ./bin/arctl dataset list
+❯ arctl dataset list
 | DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | CHUNK SIZE | CHUNK OVERLAP |
 | arcadia | 4 | zhipuai | http://localhost:8000 | text | 300 | 30 |
 ```
@@ -174,7 +175,7 @@ For example:
 
 ```shell
-❯ ./bin/arctl dataset show --name arcadia
+❯ arctl dataset show --name arcadia
 I1012 17:57:17.026985    7609 dataset.go:206]
 {
     "name": "arcadia",
@@ -183,6 +184,7 @@ I1012 17:57:17.026985    7609 dataset.go:206]
     "llm_api_key": "4fcceceb1666cd11808c218d6d619950.TCXUvaQCWFyIkxB3",
     "vector_store": "http://localhost:8000",
     "document_language": "text",
+    "text_splitter": "character",
     "chunk_size": 300,
     "chunk_overlap": 30,
     "files": {
@@ -239,7 +241,7 @@ Required Arguments:
 For example:
 
 ```shell
-❯ ./bin/arctl dataset delete --name arcadia
+❯ arctl dataset delete --name arcadia
 I1012 18:06:04.894410    8786 dataset.go:272] Delete dataset: arcadia
 I1012 18:06:04.894985    8786 dataset.go:303] Successfully delete dataset: arcadia
 ```
@@ -255,19 +257,17 @@
 Usage:
   arctl chat [usage] [flags]
 
 Flags:
-      --dataset string             dataset(namespace/collection) to query from (default "arcadia")
-      --enable-embedding-search    enable embedding similarity search(false by default)
+      --dataset string             dataset(namespace/collection) to query from
   -h, --help                       help for chat
-      --llm-apikey string          apiKey to access embedding/llm service.Must required when embedding similarity search is enabled
+      --llm-apikey string          apiKey to access llm service. Must be provided when embedding similarity search is enabled
       --llm-type string            llm type to use for embedding & chat(Only zhipuai,openai supported now) (default "zhipuai")
       --method string              Invoke method used when access LLM service(invoke/sse-invoke) (default "sse-invoke")
       --model string               which model to use: chatglm_lite/chatglm_std/chatglm_pro (default "chatglm_lite")
       --num-docs int               number of documents to be returned with SimilarSearch (default 5)
       --question string            question text to be asked
-      --score-threshold float      score threshold for similarity search(Higher is better)
+      --score-threshold float32    score threshold for similarity search(Higher is better)
       --temperature float32        temperature for chat (default 0.95)
       --top-p float32              top-p for chat (default 0.7)
-      --vector-store string        vector stores to use(Only chroma supported now) (default "http://localhost:8000")
 
 Global Flags:
       --home string   home directory to use (default "/Users/bjwswang/.arcadia")
@@ -308,7 +308,7 @@ Arcadia 开发的游戏中最知名的作品之一是《Second Life》(第二
 > This will chat with LLM `zhipuai` with its apikey by using model `chatglm_pro` with embedding enabled
 
 ```shell
-arctl chat --dataset arcadia--llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia"
+arctl chat --llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia" --dataset arcadia
 ```
 
 Required Arguments:
diff --git a/arctl/chat.go b/arctl/chat.go
index 7b64b2e17..c6faa47ab 100644
--- a/arctl/chat.go
+++ b/arctl/chat.go
@@ -104,9 +104,6 @@ func NewChatCmd() *cobra.Command {
 	cmd.Flags().StringVar(&dataset, "dataset", "", "dataset(namespace/collection) to query from")
 	cmd.Flags().Float32Var(&scoreThreshold, "score-threshold", 0, "score threshold for similarity search(Higher is better)")
 	cmd.Flags().IntVar(&numDocs, "num-docs", 5, "number of documents to be returned with SimilarSearch")
-	if err = cmd.MarkFlagRequired("dataset"); err != nil {
-		panic(err)
-	}
 
 	// For LLM chat
 	cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use for embedding & chat(Only zhipuai,openai supported now)")
diff --git a/arctl/dataset.go b/arctl/dataset.go
index 3bffd8a7b..650a9afe6 100644
--- a/arctl/dataset.go
+++ b/arctl/dataset.go
@@ -41,6 +41,7 @@ var (
 	vectorStore      string
 	documentLanguage string
+	textSplitter     string
 
 	chunkSize    int
 	chunkOverlap int
@@ -78,7 +79,7 @@ func DatasetListCmd() *cobra.Command {
 		Use:   "list [usage]",
 		Short: "List dataset",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			fmt.Printf("| DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | CHUNK SIZE | CHUNK OVERLAP |\n")
+			fmt.Printf("| DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | TEXT SPLITTER | CHUNK SIZE | CHUNK OVERLAP |\n")
 			err = filepath.Walk(filepath.Join(home, "dataset"), func(path string, info os.FileInfo, err error) error {
 				if err != nil {
 					return err
 				}
@@ -92,7 +93,7 @@ func DatasetListCmd() *cobra.Command {
 					return fmt.Errorf("failed to load cached dataset %s: %v", info.Name(), err)
 				}
 				// print item
-				fmt.Printf("| %s | %d | %s | %s | %s | %d | %d |\n", ds.Name, len(ds.Files), ds.LLMType, ds.VectorStore, ds.DocumentLanguage, ds.ChunkSize, ds.ChunkOverlap)
+				fmt.Printf("| %s | %d | %s | %s | %s | %s | %d | %d |\n", ds.Name, len(ds.Files), ds.LLMType, ds.VectorStore, ds.DocumentLanguage, ds.TextSplitter, ds.ChunkSize, ds.ChunkOverlap)
 				return nil
 			})
 			if err != nil {
@@ -126,6 +127,7 @@ func DatasetCreateCmd() *cobra.Command {
 			ds.LLMType = llmType
 			ds.VectorStore = vectorStore
 			ds.DocumentLanguage = documentLanguage
+			ds.TextSplitter = textSplitter
 			ds.ChunkSize = chunkSize
 			ds.ChunkOverlap = chunkOverlap
 
@@ -136,7 +138,6 @@ func DatasetCreateCmd() *cobra.Command {
 			}
 
 			// cache the dataset to local
-			klog.Infof("Caching dataset %s", dataset)
 			cache, err := json.Marshal(ds)
 			if err != nil {
 				return fmt.Errorf("failed to marshal dataset %s: %v", dataset, err)
 			}
@@ -145,8 +146,9 @@ func DatasetCreateCmd() *cobra.Command {
 			if err != nil {
 				return err
 			}
-			klog.Infof("Successfully created dataset %s", dataset)
-			return nil
+			klog.Infof("Successfully created dataset %s\n", dataset)
+
+			return showDataset(dataset)
 		},
 	}
 	cmd.Flags().StringVar(&dataset, "name", "", "dataset(namespace/collection) of the document to load into")
@@ -167,6 +169,7 @@ func DatasetCreateCmd() *cobra.Command {
 	cmd.Flags().StringVar(&vectorStore, "vector-store", "http://127.0.0.1:8000", "vector stores to use(Only chroma supported now)")
 	cmd.Flags().StringVar(&documentLanguage, "document-language", "text", "language of the document(Only text,html,csv supported now)")
+	cmd.Flags().StringVar(&textSplitter, "text-splitter", "character", "text splitter to use(Only character,token,markdown supported now)")
 	cmd.Flags().IntVar(&chunkSize, "chunk-size", 300, "chunk size for embedding")
 	cmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 30, "chunk overlap for embedding")
@@ -178,34 +181,8 @@ func DatasetShowCmd() *cobra.Command {
 		Use:   "show [usage]",
 		Short: "Load more documents to dataset",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			cachedDatasetFile, err := os.OpenFile(filepath.Join(home, "dataset", dataset), os.O_RDWR, 0644)
-			if err != nil {
-				if os.IsNotExist(err) {
-					klog.Errorf("dataset %s does not exist", dataset)
-					return nil
-				} else {
-					return fmt.Errorf("failed to open cached dataset file: %v", err)
-				}
-			}
-			defer cachedDatasetFile.Close()
-
-			data, err := io.ReadAll(cachedDatasetFile)
-			if err != nil {
-				return fmt.Errorf("failed to read cached dataset file: %v", err)
-			}
-			// Create a buffer to store the formatted JSON
-			var formattedJSON bytes.Buffer
-
-			// Indent and format the JSON
-			err = json.Indent(&formattedJSON, data, "", " ")
-			if err != nil {
-				return fmt.Errorf("failed to format cached dataset file: %v", err)
-			}
-
-			// print dataset
-			klog.Infof("\n%s", formattedJSON.String())
-
-			return nil
+			klog.Infof("Show dataset: %s\n", dataset)
+			return showDataset(dataset)
 		},
 	}
 
@@ -217,6 +194,37 @@
 	return cmd
 }
 
+func showDataset(dataset string) error {
+	cachedDatasetFile, err := os.OpenFile(filepath.Join(home, "dataset", dataset), os.O_RDWR, 0644)
+	if err != nil {
+		if os.IsNotExist(err) {
+			klog.Errorf("dataset %s does not exist", dataset)
+			return nil
+		} else {
+			return fmt.Errorf("failed to open cached dataset file: %v", err)
+		}
+	}
+	defer cachedDatasetFile.Close()
+
+	data, err := io.ReadAll(cachedDatasetFile)
+	if err != nil {
+		return fmt.Errorf("failed to read cached dataset file: %v", err)
+	}
+	// Create a buffer to store the formatted JSON
+	var formattedJSON bytes.Buffer
+
+	// Indent and format the JSON
+	err = json.Indent(&formattedJSON, data, "", " ")
+	if err != nil {
+		return fmt.Errorf("failed to format cached dataset file: %v", err)
+	}
+
+	// print dataset
+	klog.Infof("\n%s", formattedJSON.String())
+
+	return nil
+}
+
 func DatasetExecuteCmd() *cobra.Command {
 	cmd := &cobra.Command{
 		Use:   "execute [usage]",
@@ -326,6 +334,7 @@ type Dataset struct {
 	// Parameters for vectorization
 	VectorStore      string `json:"vector_store"`
 	DocumentLanguage string `json:"document_language"`
+	TextSplitter     string `json:"text_splitter"`
 	ChunkSize        int    `json:"chunk_size"`
 	ChunkOverlap     int    `json:"chunk_overlap"`
 
@@ -448,9 +457,25 @@ func (cachedDS *Dataset) loadDocument(ctx context.Context, document string) erro
 		return errors.New("unsupported document language")
 	}
 
-	split := textsplitter.NewRecursiveCharacter()
-	split.ChunkSize = chunkSize
-	split.ChunkOverlap = chunkOverlap
+	// initialize text splitter
+	var split textsplitter.TextSplitter
+	switch cachedDS.TextSplitter {
+	case "token":
+		split = textsplitter.NewTokenSplitter(
+			textsplitter.WithChunkSize(chunkSize),
+			textsplitter.WithChunkOverlap(chunkOverlap),
+		)
+	case "markdown":
+		split = textsplitter.NewMarkdownTextSplitter(
+			textsplitter.WithChunkSize(chunkSize),
+			textsplitter.WithChunkOverlap(chunkOverlap),
+		)
+	default:
+		split = textsplitter.NewRecursiveCharacter(
+			textsplitter.WithChunkSize(chunkSize),
+			textsplitter.WithChunkOverlap(chunkOverlap),
+		)
+	}
 
 	documents, err := loader.LoadAndSplit(ctx, split)
 	if err != nil {
diff --git a/pkg/embeddings/zhipuai/zhipuai.go b/pkg/embeddings/zhipuai/zhipuai.go
index 45e338e33..26e4a977a 100644
--- a/pkg/embeddings/zhipuai/zhipuai.go
+++ b/pkg/embeddings/zhipuai/zhipuai.go
@@ -51,7 +51,7 @@ func (e ZhiPuAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float3
 		e.BatchSize,
 	)
 
-	emb := make([][]float32, len(texts))
+	emb := make([][]float32, 0, len(texts))
 	for _, texts := range batchedTexts {
 		curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
 		if err != nil {
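
Note for reviewers: the heart of this change is the splitter selection added to `(*Dataset).loadDocument`. Below is a minimal standalone sketch of that logic against the same langchaingo `textsplitter` package the patch builds on; the `newSplitter` helper, the `main` function, and the sample text are illustrative assumptions for experimentation, not code from this patch.

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/textsplitter"
)

// newSplitter (hypothetical helper) mirrors the switch added in dataset.go:
// it maps the --text-splitter value onto a langchaingo text splitter and
// falls back to the recursive character splitter for "character" (the flag
// default) or any unrecognized name.
func newSplitter(kind string, chunkSize, chunkOverlap int) textsplitter.TextSplitter {
	switch kind {
	case "token":
		return textsplitter.NewTokenSplitter(
			textsplitter.WithChunkSize(chunkSize),
			textsplitter.WithChunkOverlap(chunkOverlap),
		)
	case "markdown":
		return textsplitter.NewMarkdownTextSplitter(
			textsplitter.WithChunkSize(chunkSize),
			textsplitter.WithChunkOverlap(chunkOverlap),
		)
	default:
		return textsplitter.NewRecursiveCharacter(
			textsplitter.WithChunkSize(chunkSize),
			textsplitter.WithChunkOverlap(chunkOverlap),
		)
	}
}

func main() {
	// Same defaults as the new arctl flags: --chunk-size 300, --chunk-overlap 30.
	split := newSplitter("markdown", 300, 30)
	chunks, err := split.SplitText("# Arcadia\n\nSome markdown body text to be chunked.")
	if err != nil {
		panic(err)
	}
	fmt.Printf("got %d chunk(s)\n", len(chunks))
}
```

Keeping `NewRecursiveCharacter` in the default branch preserves the pre-patch behavior both for the flag's default value "character" and for previously cached datasets whose `text_splitter` field is empty.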
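The one-line `zhipuai.go` change is also worth a note: `make([][]float32, len(texts))` creates a slice that already holds `len(texts)` nil entries, so appending each batch's embeddings onto it (the append itself sits outside the quoted hunk, so this is an assumption about the surrounding loop) would yield a result twice as long as expected, with nil slices at the front. Allocating zero length with capacity `len(texts)` fixes that. A minimal sketch of the difference:

```go
package main

import "fmt"

func main() {
	texts := []string{"a", "b", "c"}

	// Before the fix: length 3 up front, so appended embeddings land
	// after three nil placeholder entries.
	buggy := make([][]float32, len(texts))
	buggy = append(buggy, []float32{0.1}, []float32{0.2}, []float32{0.3})
	fmt.Println(len(buggy), buggy[0] == nil) // 6 true

	// After the fix: length 0 with capacity 3, so appends fill the
	// slice from the front without reallocating.
	fixed := make([][]float32, 0, len(texts))
	fixed = append(fixed, []float32{0.1}, []float32{0.2}, []float32{0.3})
	fmt.Println(len(fixed), fixed[0] != nil) // 3 true
}
```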