Merge pull request #136 from bjwswang/main
feat: able to configure text splitter in local dataset management
bjwswang authored Oct 23, 2023
2 parents b230ceb + 933416c commit 8a851e4
Showing 4 changed files with 72 additions and 50 deletions.
20 changes: 10 additions & 10 deletions arctl/README.md
@@ -81,7 +81,7 @@ You can use `arctl` to manage your dataset locally with the following commands:
 
 
 ```shell
-./bin/arctl dataset create -h
+❯ arctl dataset create -h
 Create dataset
 
 Usage:
@@ -96,6 +96,7 @@ Flags:
       --llm-apikey string      apiKey to access embedding service
       --llm-type string        llm type to use(Only zhipuai,openai supported now) (default "zhipuai")
       --name string            dataset(namespace/collection) of the document to load into
+      --text-splitter string   text splitter to use(Only character,token,markdown supported now) (default "character")
       --vector-store string    vector stores to use(Only chroma supported now) (default "http://127.0.0.1:8000")
 
 Global Flags:
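
With the new flag, creating a dataset that splits markdown files might look like `arctl dataset create --name arcadia --llm-apikey <your-key> --text-splitter markdown` — an illustrative invocation, with placeholder values, not one taken from the repo; only `character`, `token`, and `markdown` are accepted, and `character` remains the default.
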
@@ -145,7 +146,7 @@ Global Flags:
 For example:
 
 ```shell
-./bin/arctl dataset list
+❯ arctl dataset list
 | DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | CHUNK SIZE | CHUNK OVERLAP |
 | arcadia | 4 | zhipuai | http://localhost:8000 | text | 300 | 30 |
 ```
@@ -174,7 +175,7 @@ Required Arguments:
 For example:
 
 ```shell
-./bin/arctl dataset show --name arcadia
+❯ arctl dataset show --name arcadia
 I1012 17:57:17.026985 7609 dataset.go:206]
 {
     "name": "arcadia",
@@ -183,6 +184,7 @@ I1012 17:57:17.026985 7609 dataset.go:206]
     "llm_api_key": "4fcceceb1666cd11808c218d6d619950.TCXUvaQCWFyIkxB3",
     "vector_store": "http://localhost:8000",
     "document_language": "text",
+    "text_splitter": "character",
     "chunk_size": 300,
     "chunk_overlap": 30,
     "files": {
@@ -239,7 +241,7 @@ Required Arguments:
 For example:
 
 ```shell
-./bin/arctl dataset delete --name arcadia
+❯ arctl dataset delete --name arcadia
 I1012 18:06:04.894410 8786 dataset.go:272] Delete dataset: arcadia
 I1012 18:06:04.894985 8786 dataset.go:303] Successfully delete dataset: arcadia
 ```
@@ -255,19 +257,17 @@ Usage:
   arctl chat [usage] [flags]
 
 Flags:
-      --dataset string            dataset(namespace/collection) to query from (default "arcadia")
-      --enable-embedding-search   enable embedding similarity search(false by default)
+      --dataset string            dataset(namespace/collection) to query from
   -h, --help                      help for chat
-      --llm-apikey string         apiKey to access embedding/llm service.Must required when embedding similarity search is enabled
+      --llm-apikey string         apiKey to access llm service.Must required when embedding similarity search is enabled
       --llm-type string           llm type to use for embedding & chat(Only zhipuai,openai supported now) (default "zhipuai")
       --method string             Invoke method used when access LLM service(invoke/sse-invoke) (default "sse-invoke")
       --model string              which model to use: chatglm_lite/chatglm_std/chatglm_pro (default "chatglm_lite")
       --num-docs int              number of documents to be returned with SimilarSearch (default 5)
       --question string           question text to be asked
-      --score-threshold float     score threshold for similarity search(Higher is better)
+      --score-threshold float32   score threshold for similarity search(Higher is better)
       --temperature float32       temperature for chat (default 0.95)
       --top-p float32             top-p for chat (default 0.7)
-      --vector-store string       vector stores to use(Only chroma supported now) (default "http://localhost:8000")
 
 Global Flags:
       --home string   home directory to use (default "/Users/bjwswang/.arcadia")
@@ -308,7 +308,7 @@ Arcadia 开发的游戏中最知名的作品之一是《Second Life》(第二
 > This will chat with LLM `zhipuai` with its apikey by using model `chatglm_pro` with embedding enabled
 ```shell
-arctl chat --dataset arcadia--llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia"
+arctl chat --llm-apikey 26b2bc55fae40752055cadfc4792f9de.wagA4NIwg5aZJWhm --model chatglm_pro --num-docs 10 --question "介绍一下Arcadia" --dataset arcadia
 ```
 
 Required Arguments:
3 changes: 0 additions & 3 deletions arctl/chat.go
@@ -104,9 +104,6 @@ func NewChatCmd() *cobra.Command {
     cmd.Flags().StringVar(&dataset, "dataset", "", "dataset(namespace/collection) to query from")
     cmd.Flags().Float32Var(&scoreThreshold, "score-threshold", 0, "score threshold for similarity search(Higher is better)")
     cmd.Flags().IntVar(&numDocs, "num-docs", 5, "number of documents to be returned with SimilarSearch")
-    if err = cmd.MarkFlagRequired("dataset"); err != nil {
-        panic(err)
-    }
 
     // For LLM chat
     cmd.Flags().StringVar(&llmType, "llm-type", string(llms.ZhiPuAI), "llm type to use for embedding & chat(Only zhipuai,openai supported now)")
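
The three deleted lines are what made `--dataset` mandatory; without the `MarkFlagRequired` call, cobra treats the flag as optional and leaves the bound variable at its zero value. A minimal standalone sketch of that pattern — the command name and wiring here are illustrative, not arctl's actual code:

```go
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

func main() {
	var dataset string
	cmd := &cobra.Command{
		Use: "chat",
		RunE: func(cmd *cobra.Command, args []string) error {
			// With no MarkFlagRequired call, omitting --dataset is fine and
			// dataset simply stays "".
			fmt.Printf("dataset: %q\n", dataset)
			return nil
		},
	}
	cmd.Flags().StringVar(&dataset, "dataset", "", "dataset to query from")
	// Pre-commit behavior: uncommenting the next line makes omission an error.
	// _ = cmd.MarkFlagRequired("dataset")
	_ = cmd.Execute()
}
```
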
97 changes: 61 additions & 36 deletions arctl/dataset.go
@@ -41,6 +41,7 @@ var (
 
     vectorStore      string
     documentLanguage string
+    textSplitter     string
     chunkSize        int
     chunkOverlap     int
 
@@ -78,7 +79,7 @@ func DatasetListCmd() *cobra.Command {
         Use:   "list [usage]",
         Short: "List dataset",
         RunE: func(cmd *cobra.Command, args []string) error {
-            fmt.Printf("| DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | CHUNK SIZE | CHUNK OVERLAP |\n")
+            fmt.Printf("| DATASET | FILES |EMBEDDING MODEL | VECTOR STORE | DOCUMENT LANGUAGE | TEXT SPLITTER | CHUNK SIZE | CHUNK OVERLAP |\n")
             err = filepath.Walk(filepath.Join(home, "dataset"), func(path string, info os.FileInfo, err error) error {
                 if err != nil {
                     return err
@@ -92,7 +93,7 @@
                     return fmt.Errorf("failed to load cached dataset %s: %v", info.Name(), err)
                 }
                 // print item
-                fmt.Printf("| %s | %d | %s | %s | %s | %d | %d |\n", ds.Name, len(ds.Files), ds.LLMType, ds.VectorStore, ds.DocumentLanguage, ds.ChunkSize, ds.ChunkOverlap)
+                fmt.Printf("| %s | %d | %s | %s | %s | %s | %d | %d |\n", ds.Name, len(ds.Files), ds.LLMType, ds.VectorStore, ds.DocumentLanguage, ds.TextSplitter, ds.ChunkSize, ds.ChunkOverlap)
                 return nil
             })
             if err != nil {
@@ -126,6 +127,7 @@ func DatasetCreateCmd() *cobra.Command {
             ds.LLMType = llmType
             ds.VectorStore = vectorStore
             ds.DocumentLanguage = documentLanguage
+            ds.TextSplitter = textSplitter
             ds.ChunkSize = chunkSize
             ds.ChunkOverlap = chunkOverlap
 
@@ -136,7 +138,6 @@
             }
 
             // cache the dataset to local
-            klog.Infof("Caching dataset %s", dataset)
             cache, err := json.Marshal(ds)
             if err != nil {
                 return fmt.Errorf("failed to marshal dataset %s: %v", dataset, err)
@@ -145,8 +146,9 @@
             if err != nil {
                 return err
             }
-            klog.Infof("Successfully created dataset %s", dataset)
-            return nil
+            klog.Infof("Successfully created dataset %s\n", dataset)
+
+            return showDataset(dataset)
         },
     }
     cmd.Flags().StringVar(&dataset, "name", "", "dataset(namespace/collection) of the document to load into")
@@ -167,6 +169,7 @@
 
     cmd.Flags().StringVar(&vectorStore, "vector-store", "http://127.0.0.1:8000", "vector stores to use(Only chroma supported now)")
     cmd.Flags().StringVar(&documentLanguage, "document-language", "text", "language of the document(Only text,html,csv supported now)")
+    cmd.Flags().StringVar(&textSplitter, "text-splitter", "character", "text splitter to use(Only character,token,markdown supported now)")
     cmd.Flags().IntVar(&chunkSize, "chunk-size", 300, "chunk size for embedding")
     cmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 30, "chunk overlap for embedding")
 
@@ -178,34 +181,8 @@ func DatasetShowCmd() *cobra.Command {
         Use:   "show [usage]",
         Short: "Load more documents to dataset",
         RunE: func(cmd *cobra.Command, args []string) error {
-            cachedDatasetFile, err := os.OpenFile(filepath.Join(home, "dataset", dataset), os.O_RDWR, 0644)
-            if err != nil {
-                if os.IsNotExist(err) {
-                    klog.Errorf("dataset %s does not exist", dataset)
-                    return nil
-                } else {
-                    return fmt.Errorf("failed to open cached dataset file: %v", err)
-                }
-            }
-            defer cachedDatasetFile.Close()
-
-            data, err := io.ReadAll(cachedDatasetFile)
-            if err != nil {
-                return fmt.Errorf("failed to read cached dataset file: %v", err)
-            }
-            // Create a buffer to store the formatted JSON
-            var formattedJSON bytes.Buffer
-
-            // Indent and format the JSON
-            err = json.Indent(&formattedJSON, data, "", " ")
-            if err != nil {
-                return fmt.Errorf("failed to format cached dataset file: %v", err)
-            }
-
-            // print dataset
-            klog.Infof("\n%s", formattedJSON.String())
-
-            return nil
+            klog.Infof("Show dataset: %s \n", dataset)
+            return showDataset(dataset)
         },
     }
 
@@ -217,6 +194,37 @@
     return cmd
 }
 
+func showDataset(dataset string) error {
+    cachedDatasetFile, err := os.OpenFile(filepath.Join(home, "dataset", dataset), os.O_RDWR, 0644)
+    if err != nil {
+        if os.IsNotExist(err) {
+            klog.Errorf("dataset %s does not exist", dataset)
+            return nil
+        } else {
+            return fmt.Errorf("failed to open cached dataset file: %v", err)
+        }
+    }
+    defer cachedDatasetFile.Close()
+
+    data, err := io.ReadAll(cachedDatasetFile)
+    if err != nil {
+        return fmt.Errorf("failed to read cached dataset file: %v", err)
+    }
+    // Create a buffer to store the formatted JSON
+    var formattedJSON bytes.Buffer
+
+    // Indent and format the JSON
+    err = json.Indent(&formattedJSON, data, "", " ")
+    if err != nil {
+        return fmt.Errorf("failed to format cached dataset file: %v", err)
+    }
+
+    // print dataset
+    klog.Infof("\n%s", formattedJSON.String())
+
+    return nil
+}
+
 func DatasetExecuteCmd() *cobra.Command {
     cmd := &cobra.Command{
         Use: "execute [usage]",
@@ -326,6 +334,7 @@ type Dataset struct {
     // Parameters for vectorization
     VectorStore      string `json:"vector_store"`
    DocumentLanguage string `json:"document_language"`
+    TextSplitter     string `json:"text_splitter"`
     ChunkSize        int    `json:"chunk_size"`
     ChunkOverlap     int    `json:"chunk_overlap"`
 
@@ -448,9 +457,25 @@ func (cachedDS *Dataset) loadDocument(ctx context.Context, document string) error {
         return errors.New("unsupported document language")
     }
 
-    split := textsplitter.NewRecursiveCharacter()
-    split.ChunkSize = chunkSize
-    split.ChunkOverlap = chunkOverlap
+    // initialize text splitter
+    var split textsplitter.TextSplitter
+    switch cachedDS.TextSplitter {
+    case "token":
+        split = textsplitter.NewTokenSplitter(
+            textsplitter.WithChunkSize(chunkSize),
+            textsplitter.WithChunkOverlap(chunkOverlap),
+        )
+    case "markdown":
+        split = textsplitter.NewMarkdownTextSplitter(
+            textsplitter.WithChunkSize(chunkSize),
+            textsplitter.WithChunkOverlap(chunkOverlap),
+        )
+    default:
+        split = textsplitter.NewRecursiveCharacter(
+            textsplitter.WithChunkSize(chunkSize),
+            textsplitter.WithChunkOverlap(chunkOverlap),
+        )
+    }
 
     documents, err := loader.LoadAndSplit(ctx, split)
     if err != nil {
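The switch above selects one of three langchaingo splitters by the name stored in the dataset. A self-contained sketch of how a selected splitter behaves on its own — assuming `github.com/tmc/langchaingo` is available; the sample text and sizes are made up:

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/textsplitter"
)

func main() {
	// Same construction the diff uses for the "markdown" case.
	split := textsplitter.NewMarkdownTextSplitter(
		textsplitter.WithChunkSize(300),
		textsplitter.WithChunkOverlap(30),
	)

	// SplitText chunks a raw string; LoadAndSplit in the diff drives the
	// same splitter through a document loader instead.
	chunks, err := split.SplitText("# Arcadia\n\nSome body text to be chunked...")
	if err != nil {
		panic(err)
	}
	for i, c := range chunks {
		fmt.Printf("chunk %d: %q\n", i, c)
	}
}
```
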
2 changes: 1 addition & 1 deletion pkg/embeddings/zhipuai/zhipuai.go
@@ -51,7 +51,7 @@ func (e ZhiPuAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) {
         e.BatchSize,
     )
 
-    emb := make([][]float32, len(texts))
+    emb := make([][]float32, 0, len(texts))
     for _, texts := range batchedTexts {
         curTextEmbeddings, err := e.client.CreateEmbedding(ctx, texts)
         if err != nil {
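
The one-line fix above addresses a classic len-versus-cap slip: `make([][]float32, n)` creates `n` zero-value rows up front, so a later `append` grows the slice to `2n` entries with the first `n` still nil. A minimal illustration — assuming the embeddings are accumulated with `append`, as the surrounding loop suggests:

```go
package main

import "fmt"

func main() {
	texts := []string{"a", "b", "c"}

	buggy := make([][]float32, len(texts))    // len 3: three nil rows already present
	fixed := make([][]float32, 0, len(texts)) // len 0, cap 3: empty, room for 3

	for range texts {
		row := []float32{1.0}
		buggy = append(buggy, row) // appends after the nil rows
		fixed = append(fixed, row)
	}

	fmt.Println(len(buggy), buggy[0] == nil) // 6 true
	fmt.Println(len(fixed), fixed[0] != nil) // 3 true
}
```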