From b01bbdb3672ccd2232157717972fe128c9387c76 Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Wed, 10 Jan 2024 13:44:26 +0800 Subject: [PATCH] feat: knowledgebase support more granular update Signed-off-by: Abirdcfly --- .github/workflows/pgvector_image_build.yml | 2 + api/base/v1alpha1/knowledgebase.go | 5 + .../app-node/chain/llmchain_controller.go | 4 + .../chain/retrieval_qa_chain_controller.go | 4 + .../app-node/prompt/prompt_controller.go | 4 + .../knowledgebase_retriever_controller.go | 4 + controllers/base/application_controller.go | 4 + controllers/base/knowledgebase_controller.go | 136 ++++++++---- deploy/charts/arcadia/Chart.yaml | 2 +- deploy/charts/arcadia/templates/config.yaml | 9 +- .../arcadia/templates/post-vectorstore.yaml | 2 + deploy/charts/arcadia/values.yaml | 6 +- main.go | 3 +- pkg/llms/zhipuai/api.go | 4 +- pkg/vectorstore/pgvector.go | 195 ++++++++++++++++++ pkg/vectorstore/vectorstore.go | 121 ++++------- tests/deploy-values.yaml | 3 + tests/example-test.sh | 55 ++++- 18 files changed, 429 insertions(+), 134 deletions(-) create mode 100644 pkg/vectorstore/pgvector.go diff --git a/.github/workflows/pgvector_image_build.yml b/.github/workflows/pgvector_image_build.yml index 73eaba655..170960ab4 100644 --- a/.github/workflows/pgvector_image_build.yml +++ b/.github/workflows/pgvector_image_build.yml @@ -3,6 +3,8 @@ name: Build pgvector images on: pull_request: branches: [main] + paths: + - 'deploy/pgvector/Dockerfile' push: branches: [main] paths: diff --git a/api/base/v1alpha1/knowledgebase.go b/api/base/v1alpha1/knowledgebase.go index de685effa..a79b51e71 100644 --- a/api/base/v1alpha1/knowledgebase.go +++ b/api/base/v1alpha1/knowledgebase.go @@ -5,6 +5,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const ( + // UpdateSourceFileAnnotationKey is the key of the update source file annotation + UpdateSourceFileAnnotationKey = Group + "/update-source-file-time" +) + func (kb *KnowledgeBase) VectorStoreCollectionName() string { 
return kb.Namespace + "_" + kb.Name } diff --git a/controllers/app-node/chain/llmchain_controller.go b/controllers/app-node/chain/llmchain_controller.go index 4151b4b47..0d3779d26 100644 --- a/controllers/app-node/chain/llmchain_controller.go +++ b/controllers/app-node/chain/llmchain_controller.go @@ -18,6 +18,7 @@ package chain import ( "context" + "reflect" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" @@ -127,6 +128,9 @@ func (r *LLMChainReconciler) patchStatus(ctx context.Context, instance *api.LLMC if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { return err } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } patch := client.MergeFrom(latest.DeepCopy()) latest.Status = instance.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("LLMChain-controller")) diff --git a/controllers/app-node/chain/retrieval_qa_chain_controller.go b/controllers/app-node/chain/retrieval_qa_chain_controller.go index a6fd74dd3..31382474e 100644 --- a/controllers/app-node/chain/retrieval_qa_chain_controller.go +++ b/controllers/app-node/chain/retrieval_qa_chain_controller.go @@ -18,6 +18,7 @@ package chain import ( "context" + "reflect" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" @@ -127,6 +128,9 @@ func (r *RetrievalQAChainReconciler) patchStatus(ctx context.Context, instance * if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { return err } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } patch := client.MergeFrom(latest.DeepCopy()) latest.Status = instance.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("RetrievalQAChain-controller")) diff --git a/controllers/app-node/prompt/prompt_controller.go b/controllers/app-node/prompt/prompt_controller.go index 1f4cc5d7d..353cf1242 100644 --- a/controllers/app-node/prompt/prompt_controller.go +++ 
b/controllers/app-node/prompt/prompt_controller.go @@ -18,6 +18,7 @@ package chain import ( "context" + "reflect" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" @@ -125,6 +126,9 @@ func (r *PromptReconciler) patchStatus(ctx context.Context, instance *api.Prompt if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { return err } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } patch := client.MergeFrom(latest.DeepCopy()) latest.Status = instance.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("Prompt-controller")) diff --git a/controllers/app-node/retriever/knowledgebase_retriever_controller.go b/controllers/app-node/retriever/knowledgebase_retriever_controller.go index 21931277b..93c175ded 100644 --- a/controllers/app-node/retriever/knowledgebase_retriever_controller.go +++ b/controllers/app-node/retriever/knowledgebase_retriever_controller.go @@ -18,6 +18,7 @@ package chain import ( "context" + "reflect" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" @@ -127,6 +128,9 @@ func (r *KnowledgeBaseRetrieverReconciler) patchStatus(ctx context.Context, inst if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { return err } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } patch := client.MergeFrom(latest.DeepCopy()) latest.Status = instance.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("KnowledgeBaseRetriever-controller")) diff --git a/controllers/base/application_controller.go b/controllers/base/application_controller.go index d2071c55f..2e2f26b3a 100644 --- a/controllers/base/application_controller.go +++ b/controllers/base/application_controller.go @@ -18,6 +18,7 @@ package controllers import ( "context" + "reflect" "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" @@ -192,6 +193,9 @@ func (r *ApplicationReconciler) patchStatus(ctx context.Context, app 
*arcadiav1a if err := r.Client.Get(ctx, client.ObjectKeyFromObject(app), latest); err != nil { return err } + if reflect.DeepEqual(app.Status, latest.Status) { + return nil + } patch := client.MergeFrom(latest.DeepCopy()) latest.Status = app.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("application-controller")) diff --git a/controllers/base/knowledgebase_controller.go b/controllers/base/knowledgebase_controller.go index d14da6d66..f2152d1f9 100644 --- a/controllers/base/knowledgebase_controller.go +++ b/controllers/base/knowledgebase_controller.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "path/filepath" + "reflect" "sync" "time" @@ -31,6 +32,7 @@ import ( "github.com/tmc/langchaingo/documentloaders" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/textsplitter" + corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -68,6 +70,8 @@ type KnowledgeBaseReconciler struct { Scheme *runtime.Scheme mu sync.Mutex HasHandledSuccessPath map[string]bool + readyMu sync.Mutex + ReadyMap map[string]bool } //+kubebuilder:rbac:groups=arcadia.kubeagi.k8s.com.cn,resources=knowledgebases,verbs=get;list;watch;create;update;patch;delete @@ -129,19 +133,29 @@ func (r *KnowledgeBaseReconciler) Reconcile(ctx context.Context, req ctrl.Reques kb, result, err = r.reconcile(ctx, log, kb) // Update status after reconciliation. 
- if updateStatusErr := r.patchStatus(ctx, kb); updateStatusErr != nil { + if updateStatusErr := r.patchStatus(ctx, log, kb); updateStatusErr != nil { log.Error(updateStatusErr, "unable to update status after reconciliation") return ctrl.Result{Requeue: true}, updateStatusErr } + log.V(5).Info("Reconcile done") return result, err } -func (r *KnowledgeBaseReconciler) patchStatus(ctx context.Context, kb *arcadiav1alpha1.KnowledgeBase) error { +func (r *KnowledgeBaseReconciler) patchStatus(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) error { latest := &arcadiav1alpha1.KnowledgeBase{} if err := r.Client.Get(ctx, client.ObjectKeyFromObject(kb), latest); err != nil { return err } + if reflect.DeepEqual(kb.Status, latest.Status) { + log.V(5).Info("status not changed, skip") + return nil + } + if r.isReady(kb) && !kb.Status.IsReady() { + log.V(5).Info("status is ready,but not get it from cluster, has cache, skip update status") + return nil + } + log.V(5).Info(fmt.Sprintf("try to patch status %#v", kb.Status)) patch := client.MergeFrom(latest.DeepCopy()) latest.Status = kb.Status return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("knowledgebase-controller")) @@ -155,17 +169,32 @@ func (r *KnowledgeBaseReconciler) SetupWithManager(mgr ctrl.Manager) error { } func (r *KnowledgeBaseReconciler) reconcile(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) (*arcadiav1alpha1.KnowledgeBase, ctrl.Result, error) { - // Observe generation change - if kb.Status.ObservedGeneration != kb.Generation { - kb.Status.ObservedGeneration = kb.Generation - kb = r.setCondition(kb, kb.InitCondition()) - if updateStatusErr := r.patchStatus(ctx, kb); updateStatusErr != nil { + // Observe generation change or manual update + if kb.Status.ObservedGeneration != kb.Generation || kb.Annotations[arcadiav1alpha1.UpdateSourceFileAnnotationKey] != "" { + r.cleanupHasHandledSuccessPath(kb) + if kb.Status.ObservedGeneration != kb.Generation { 
+ log.Info("Generation changed") + kb.Status.ObservedGeneration = kb.Generation + } + kb = r.setCondition(log, kb, kb.InitCondition()) + if updateStatusErr := r.patchStatus(ctx, log, kb); updateStatusErr != nil { log.Error(updateStatusErr, "unable to update status after generation update") return kb, ctrl.Result{Requeue: true}, updateStatusErr } + if kb.Annotations[arcadiav1alpha1.UpdateSourceFileAnnotationKey] != "" { + log.Info("Manual update") + kbNew := kb.DeepCopy() + delete(kbNew.Annotations, arcadiav1alpha1.UpdateSourceFileAnnotationKey) + err := r.Patch(ctx, kbNew, client.MergeFrom(kb)) + if err != nil { + return kb, ctrl.Result{Requeue: true}, err + } + } + return kb, ctrl.Result{}, nil } - if kb.Status.IsReady() { + if kb.Status.IsReady() || r.isReady(kb) { + log.Info("KnowledgeBase is ready, skip reconcile") return kb, ctrl.Result{}, nil } @@ -173,27 +202,27 @@ func (r *KnowledgeBaseReconciler) reconcile(ctx context.Context, log logr.Logger vectorStoreReq := kb.Spec.VectorStore fileGroupsReq := kb.Spec.FileGroups if embedderReq == nil || vectorStoreReq == nil || len(fileGroupsReq) == 0 { - kb = r.setCondition(kb, kb.PendingCondition("embedder or vectorstore or filegroups is not setting")) + kb = r.setCondition(log, kb, kb.PendingCondition("embedder or vectorstore or filegroups is not setting")) return kb, ctrl.Result{}, nil } embedder := &arcadiav1alpha1.Embedder{} if err := r.Get(ctx, types.NamespacedName{Name: kb.Spec.Embedder.Name, Namespace: kb.Spec.Embedder.GetNamespace(kb.GetNamespace())}, embedder); err != nil { if apierrors.IsNotFound(err) { - kb = r.setCondition(kb, kb.PendingCondition("embedder is not found")) + kb = r.setCondition(log, kb, kb.PendingCondition("embedder is not found")) return kb, ctrl.Result{RequeueAfter: waitLonger}, nil } - kb = r.setCondition(kb, kb.ErrorCondition(err.Error())) + kb = r.setCondition(log, kb, kb.ErrorCondition(err.Error())) return kb, ctrl.Result{}, err } vectorStore := &arcadiav1alpha1.VectorStore{} if err 
:= r.Get(ctx, types.NamespacedName{Name: kb.Spec.VectorStore.Name, Namespace: kb.Spec.VectorStore.GetNamespace(kb.GetNamespace())}, vectorStore); err != nil { if apierrors.IsNotFound(err) { - kb = r.setCondition(kb, kb.PendingCondition("vectorStore is not found")) + kb = r.setCondition(log, kb, kb.PendingCondition("vectorStore is not found")) return kb, ctrl.Result{RequeueAfter: waitLonger}, nil } - kb = r.setCondition(kb, kb.ErrorCondition(err.Error())) + kb = r.setCondition(log, kb, kb.ErrorCondition(err.Error())) return kb, ctrl.Result{}, err } @@ -205,24 +234,35 @@ func (r *KnowledgeBaseReconciler) reconcile(ctx context.Context, log logr.Logger } } if err := errors.Join(errs...); err != nil { - kb = r.setCondition(kb, kb.ErrorCondition(err.Error())) + kb = r.setCondition(log, kb, kb.ErrorCondition(err.Error())) return kb, ctrl.Result{RequeueAfter: waitMedium}, nil } else { for _, fileGroupDetail := range kb.Status.FileGroupDetail { for _, fileDetail := range fileGroupDetail.FileDetails { if fileDetail.Phase == arcadiav1alpha1.FileProcessPhaseFailed && fileDetail.ErrMessage != "" { - kb = r.setCondition(kb, kb.ErrorCondition(fileDetail.ErrMessage)) + kb = r.setCondition(log, kb, kb.ErrorCondition(fileDetail.ErrMessage)) return kb, ctrl.Result{RequeueAfter: waitMedium}, nil } } } - kb = r.setCondition(kb, kb.ReadyCondition()) + kb = r.setCondition(log, kb, kb.ReadyCondition()) } - return kb, ctrl.Result{}, nil } -func (r *KnowledgeBaseReconciler) setCondition(kb *arcadiav1alpha1.KnowledgeBase, condition ...arcadiav1alpha1.Condition) *arcadiav1alpha1.KnowledgeBase { +func (r *KnowledgeBaseReconciler) setCondition(log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase, condition ...arcadiav1alpha1.Condition) *arcadiav1alpha1.KnowledgeBase { + ready := false + for _, c := range condition { + if c.Type == arcadiav1alpha1.TypeReady && c.Status == corev1.ConditionTrue { + ready = true + break + } + } + if ready { + r.ready(log, kb) + } else { + r.unready(log, kb) + } 
kb.Status.SetConditions(condition...) return kb } @@ -270,6 +310,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo // brand new knowledgebase, init status. kb.Status.FileGroupDetail = make([]arcadiav1alpha1.FileGroupDetail, 1) kb.Status.FileGroupDetail[0].Init(group) + log.V(5).Info("init filegroupdetail status") } var fileGroupDetail *arcadiav1alpha1.FileGroupDetail pathMap := make(map[string]*arcadiav1alpha1.FileDetails, 1) @@ -284,6 +325,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo } if fileGroupDetail == nil { // this group is newly added + log.V(5).Info("new added group, init filegroupdetail status") fileGroupDetail = &arcadiav1alpha1.FileGroupDetail{} fileGroupDetail.Init(group) kb.Status.FileGroupDetail = append(kb.Status.FileGroupDetail, *fileGroupDetail) @@ -376,6 +418,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo r.HasHandledSuccessPath[r.hasHandledPathKey(kb, group, path)] = true r.mu.Unlock() fileDetail.UpdateErr(nil, arcadiav1alpha1.FileProcessPhaseSucceeded) + log.Info("handle FileGroup succeeded") } return errors.Join(errs...) 
} @@ -458,38 +501,18 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge return err } } - for i, doc := range documents { - log.V(5).Info(fmt.Sprintf("document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata)) - } - s, finish, err := vectorstore.NewVectorStore(ctx, store, em, kb.VectorStoreCollectionName(), r.Client, nil) - if err != nil { - return err - } - log.Info("handle file: add documents to embedder") - if _, err = s.AddDocuments(ctx, documents); err != nil { - return err - } - if finish != nil { - finish() - } - log.Info("handle file succeeded") - return nil + return vectorstore.AddDocuments(ctx, log, store, em, kb.VectorStoreCollectionName(), r.Client, nil, documents) } func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) { - r.mu.Lock() - for _, fg := range kb.Spec.FileGroups { - for _, path := range fg.Paths { - delete(r.HasHandledSuccessPath, r.hasHandledPathKey(kb, fg, path)) - } - } - r.mu.Unlock() + r.cleanupHasHandledSuccessPath(kb) + r.unready(log, kb) vectorStore := &arcadiav1alpha1.VectorStore{} if err := r.Get(ctx, types.NamespacedName{Name: kb.Spec.VectorStore.Name, Namespace: kb.Spec.VectorStore.GetNamespace(kb.GetNamespace())}, vectorStore); err != nil { log.Error(err, "reconcile delete: get vector store error, may leave garbage data") return } - _ = vectorstore.RemoveCollection(ctx, log, vectorStore, kb.VectorStoreCollectionName()) + _ = vectorstore.RemoveCollection(ctx, log, vectorStore, kb.VectorStoreCollectionName(), r.Client, nil) } func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.KnowledgeBase, filegroup arcadiav1alpha1.FileGroup, path string) string { @@ -499,3 +522,32 @@ func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.Knowledg } return kb.Name + "/" + kb.Namespace + "/" + sourceName + "/" + path } + +func (r *KnowledgeBaseReconciler) cleanupHasHandledSuccessPath(kb 
*arcadiav1alpha1.KnowledgeBase) { + r.mu.Lock() + for _, fg := range kb.Spec.FileGroups { + for _, path := range fg.Paths { + delete(r.HasHandledSuccessPath, r.hasHandledPathKey(kb, fg, path)) + } + } + r.mu.Unlock() +} + +func (r *KnowledgeBaseReconciler) ready(log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) { + r.readyMu.Lock() + defer r.readyMu.Unlock() + log.V(5).Info("ready") + r.ReadyMap[string(kb.GetUID())] = true +} + +func (r *KnowledgeBaseReconciler) unready(log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) { + r.readyMu.Lock() + defer r.readyMu.Unlock() + log.V(5).Info("unready") + delete(r.ReadyMap, string(kb.GetUID())) +} + +func (r *KnowledgeBaseReconciler) isReady(kb *arcadiav1alpha1.KnowledgeBase) bool { + v, ok := r.ReadyMap[string(kb.GetUID())] + return ok && v +} diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 6b0700271..d04dbe5ba 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.2.11 +version: 0.2.12 appVersion: "0.1.0" keywords: diff --git a/deploy/charts/arcadia/templates/config.yaml b/deploy/charts/arcadia/templates/config.yaml index fe542f7a5..e2919746a 100644 --- a/deploy/charts/arcadia/templates/config.yaml +++ b/deploy/charts/arcadia/templates/config.yaml @@ -1,4 +1,3 @@ -{{- if .Values.postgresql.enabled }} apiVersion: v1 data: config: | @@ -25,7 +24,12 @@ data: vectorStore: apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: VectorStore +{{- if and (.Values.chromadb.enabled) (eq .Values.global.defaultVectorStoreType "chroma") }} name: '{{ .Release.Name }}-vectorstore' +{{- end }} +{{- if and (.Values.postgresql.enabled) (eq .Values.global.defaultVectorStoreType "pgvector") }} + name: '{{ .Release.Name }}-pgvector-vectorstore' +{{- end }} namespace: '{{ .Release.Namespace }}' #streamlit: @@ -36,16 +40,17 @@ data: 
dataprocess: | llm: qa_retry_count: {{ .Values.dataprocess.config.llm.qa_retry_count }} +{{- if .Values.postgresql.enabled }} postgresql: host: {{ .Release.Name }}-postgresql.{{ .Release.Namespace }}.svc.cluster.local port: {{ .Values.postgresql.containerPorts.postgresql }} user: {{ .Values.postgresql.global.postgresql.auth.username }} password: {{ .Values.postgresql.global.postgresql.auth.password }} database: {{ .Values.postgresql.global.postgresql.auth.database }} +{{- end }} kind: ConfigMap metadata: labels: control-plane: {{ .Release.Name }}-arcadia name: {{ .Release.Name }}-config namespace: {{ .Release.Namespace }} -{{- end }} \ No newline at end of file diff --git a/deploy/charts/arcadia/templates/post-vectorstore.yaml b/deploy/charts/arcadia/templates/post-vectorstore.yaml index 7d3905e0f..9beb3531f 100644 --- a/deploy/charts/arcadia/templates/post-vectorstore.yaml +++ b/deploy/charts/arcadia/templates/post-vectorstore.yaml @@ -1,3 +1,4 @@ +{{- if .Values.chromadb.enabled }} apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: VectorStore metadata: @@ -13,6 +14,7 @@ spec: url: 'http://{{ .Release.Name }}-chromadb.{{ .Release.Namespace }}.svc.cluster.local:{{ .Values.chromadb.chromadb.serverHttpPort }}' chroma: distanceFunction: cosine +{{- end }} {{- if .Values.postgresql.enabled }} --- diff --git a/deploy/charts/arcadia/values.yaml b/deploy/charts/arcadia/values.yaml index 504bfb628..3ca6a2651 100644 --- a/deploy/charts/arcadia/values.yaml +++ b/deploy/charts/arcadia/values.yaml @@ -1,6 +1,10 @@ global: oss: bucket: &default-oss-bucket "arcadia" + ## @param global.defaultVectorStoreType Defines the default vector database type, currently `chroma` and `pgvector` are available + ## When the option is `chroma`, it needs `chromadb.enabled` to be `true` as well to work. + ## When the option is `pgvector`, it needs `postgresql.enabled` to be `true` as well to work. 
+ defaultVectorStoreType: pgvector # @section controller is used as the core controller for arcadia # @param image Image to be used @@ -100,7 +104,7 @@ minio: # @section chromadb is used to deploy a chromadb instance chromadb: - enabled: true + enabled: false image: repository: kubeagi/chromadb chromadb: diff --git a/main.go b/main.go index 88b6c9da1..1c5478a9e 100644 --- a/main.go +++ b/main.go @@ -209,7 +209,8 @@ func main() { if err = (&basecontrollers.KnowledgeBaseReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), - HasHandledSuccessPath: make(map[string]bool, 0), + HasHandledSuccessPath: make(map[string]bool), + ReadyMap: make(map[string]bool), }).SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "KnowledgeBase") os.Exit(1) diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go index 7539cec6b..d04473559 100644 --- a/pkg/llms/zhipuai/api.go +++ b/pkg/llms/zhipuai/api.go @@ -209,7 +209,9 @@ func (z *ZhiPuAI) CreateEmbedding(ctx context.Context, inputTexts []string) ([][ success = true } } - + if postResponse == nil { + return nil, errors.New("max retry reached, embedding post failed") + } if !postResponse.Success { return nil, fmt.Errorf("embedding post failed:\n%s", postResponse.String()) } diff --git a/pkg/vectorstore/pgvector.go b/pkg/vectorstore/pgvector.go new file mode 100644 index 000000000..2c3600f6d --- /dev/null +++ b/pkg/vectorstore/pgvector.go @@ -0,0 +1,195 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vectorstore + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-logr/logr" + "github.com/jackc/pgx/v5" + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/llms/openai" + lanchaingoschema "github.com/tmc/langchaingo/schema" + "github.com/tmc/langchaingo/vectorstores" + "github.com/tmc/langchaingo/vectorstores/pgvector" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" + "sigs.k8s.io/controller-runtime/pkg/client" + + arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/pkg/datasource" + "github.com/kubeagi/arcadia/pkg/utils" +) + +var _ vectorstores.VectorStore = (*PGVectorStore)(nil) + +type PGVectorStore struct { + *pgx.Conn + pgvector.Store + *arcadiav1alpha1.PGVector +} + +func NewPGVectorStore(ctx context.Context, vs *arcadiav1alpha1.VectorStore, c client.Client, dc dynamic.Interface, embedder embeddings.Embedder, collectionName string) (v *PGVectorStore, finish func(), err error) { + v = &PGVectorStore{PGVector: vs.Spec.PGVector} + ops := []pgvector.Option{ + pgvector.WithPreDeleteCollection(vs.Spec.PGVector.PreDeleteCollection), + } + if vs.Spec.PGVector.CollectionTableName != "" { + ops = append(ops, pgvector.WithCollectionTableName(vs.Spec.PGVector.CollectionTableName)) + } else { + v.PGVector.CollectionTableName = pgvector.DefaultCollectionStoreTableName + } + if vs.Spec.PGVector.EmbeddingTableName != "" { + ops = append(ops, pgvector.WithEmbeddingTableName(vs.Spec.PGVector.EmbeddingTableName)) + } else { + v.PGVector.EmbeddingTableName = pgvector.DefaultEmbeddingStoreTableName + } + if ref := vs.Spec.PGVector.DataSourceRef; ref != nil { + if err := utils.ValidateClient(c, dc); err != nil { + return nil, 
nil, err + } + ds := &arcadiav1alpha1.Datasource{} + if c != nil { + if err := c.Get(ctx, types.NamespacedName{Name: ref.Name, Namespace: ref.GetNamespace(vs.GetNamespace())}, ds); err != nil { + return nil, nil, err + } + } else { + obj, err := dc.Resource(schema.GroupVersionResource{Group: "arcadia.kubeagi.k8s.com.cn", Version: "v1alpha1", Resource: "datasources"}). + Namespace(ref.GetNamespace(vs.GetNamespace())).Get(ctx, ref.Name, metav1.GetOptions{}) + if err != nil { + return nil, nil, err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), ds) + if err != nil { + return nil, nil, err + } + } + vs.Spec.Endpoint = ds.Spec.Endpoint.DeepCopy() + pool, err := datasource.GetPostgreSQLPool(ctx, c, dc, ds) + if err != nil { + return nil, nil, err + } + conn, err := pool.Acquire(ctx) + if err != nil { + return nil, nil, err + } + klog.V(5).Info("acquire pg conn from pool") + finish = func() { + if conn != nil { + conn.Release() + klog.V(5).Info("release pg conn to pool") + } + } + v.Conn = conn.Conn() + ops = append(ops, pgvector.WithConn(v.Conn)) + } else { + conn, err := pgx.Connect(ctx, vs.Spec.Endpoint.URL) + if err != nil { + return nil, nil, err + } + v.Conn = conn + ops = append(ops, pgvector.WithConn(conn)) + } + if embedder != nil { + ops = append(ops, pgvector.WithEmbedder(embedder)) + } else { + llm, _ := openai.New() + embedder, _ = embeddings.NewEmbedder(llm) + } + ops = append(ops, pgvector.WithEmbedder(embedder)) + if collectionName != "" { + ops = append(ops, pgvector.WithCollectionName(collectionName)) + v.PGVector.CollectionName = collectionName + } else { + ops = append(ops, pgvector.WithCollectionName(vs.Spec.PGVector.CollectionName)) + } + store, err := pgvector.New(ctx, ops...) 
+ if err != nil { + return nil, nil, err + } + v.Store = store + return v, finish, nil +} + +// RemoveExist remove exist document from pgvector +// Note: it is currently assumed that the embedder of a knowledge base is constant that means the result of embedding a fixed document is fixed, +// disregarding the case where the embedder changes (and if it does, a lot of processing will need to be done in many places, not just here) +func (s *PGVectorStore) RemoveExist(ctx context.Context, log logr.Logger, document []lanchaingoschema.Document) (doc []lanchaingoschema.Document, err error) { + // get collection_uuid from collection_table, if null, means no exits + collectionUUID := "" + sql := fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.PGVector.CollectionTableName) + err = s.Conn.QueryRow(ctx, sql, s.PGVector.CollectionName).Scan(&collectionUUID) + if collectionUUID == "" { + return document, err + } + in := make([]string, 0) + for _, d := range document { + in = append(in, d.PageContent) + } + sql = fmt.Sprintf(`SELECT document, cmetadata FROM %s WHERE collection_id = $1 AND document in ('%s')`, s.PGVector.EmbeddingTableName, strings.Join(in, "', '")) + rows, err := s.Conn.Query(ctx, sql, collectionUUID) + if err != nil { + return nil, err + } + res := make(map[string]lanchaingoschema.Document, 0) + for rows.Next() { + doc := lanchaingoschema.Document{} + if err := rows.Scan(&doc.PageContent, &doc.Metadata); err != nil { + return nil, err + } + res[doc.PageContent] = doc + } + if len(res) == 0 { + return document, nil + } + if len(res) == len(document) { + return nil, nil + } + for page := range res { + log.V(5).Info(fmt.Sprintf("filter out exist documents[%s]", page)) + } + doc = make([]lanchaingoschema.Document, 0, len(document)) + for _, d := range document { + has, ok := res[d.PageContent] + if ok { + // The value returned from the database is of type float64, + // but the original document is of type int + if v, ok := 
has.Metadata["lineNumber"]; ok { + has.Metadata["lineNumber"] = int(v.(float64)) + } + if reflect.DeepEqual(has.Metadata, d.Metadata) { + continue + } + log.V(5).Info(fmt.Sprintf("exist document, same page content:%s, raw metadata:%v has metadata:%v", d.PageContent, d.Metadata, has.Metadata)) + for k, v := range d.Metadata { + hasV := has.Metadata[k] + if !reflect.DeepEqual(v, hasV) { + log.V(5).Info(fmt.Sprintf("different metadata: raw:[%T]%v has:[%T]%v", v, v, hasV, hasV)) + } + } + } + doc = append(doc, d) + } + return doc, nil +} diff --git a/pkg/vectorstore/vectorstore.go b/pkg/vectorstore/vectorstore.go index 4aaaa368f..825da99c4 100644 --- a/pkg/vectorstore/vectorstore.go +++ b/pkg/vectorstore/vectorstore.go @@ -19,24 +19,17 @@ package vectorstore import ( "context" "errors" + "fmt" "github.com/go-logr/logr" "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/openai" + lanchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" "github.com/tmc/langchaingo/vectorstores/chroma" - "github.com/tmc/langchaingo/vectorstores/pgvector" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/dynamic" - "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/datasource" - "github.com/kubeagi/arcadia/pkg/utils" ) var ( @@ -60,68 +53,7 @@ func NewVectorStore(ctx context.Context, vs *arcadiav1alpha1.VectorStore, embedd } v, err = chroma.New(ops...) 
case arcadiav1alpha1.VectorStoreTypePGVector: - ops := []pgvector.Option{ - pgvector.WithPreDeleteCollection(vs.Spec.PGVector.PreDeleteCollection), - } - if vs.Spec.PGVector.CollectionTableName != "" { - ops = append(ops, pgvector.WithCollectionTableName(vs.Spec.PGVector.CollectionTableName)) - } - if vs.Spec.PGVector.EmbeddingTableName != "" { - ops = append(ops, pgvector.WithEmbeddingTableName(vs.Spec.PGVector.EmbeddingTableName)) - } - if ref := vs.Spec.PGVector.DataSourceRef; ref != nil { - if err := utils.ValidateClient(c, dc); err != nil { - return nil, nil, err - } - ds := &arcadiav1alpha1.Datasource{} - if c != nil { - if err := c.Get(ctx, types.NamespacedName{Name: ref.Name, Namespace: ref.GetNamespace(vs.GetNamespace())}, ds); err != nil { - return nil, nil, err - } - } else { - obj, err := dc.Resource(schema.GroupVersionResource{Group: "arcadia.kubeagi.k8s.com.cn", Version: "v1alpha1", Resource: "datasources"}). - Namespace(ref.GetNamespace(vs.GetNamespace())).Get(ctx, ref.Name, metav1.GetOptions{}) - if err != nil { - return nil, nil, err - } - err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), ds) - if err != nil { - return nil, nil, err - } - } - vs.Spec.Endpoint = ds.Spec.Endpoint.DeepCopy() - pool, err := datasource.GetPostgreSQLPool(ctx, c, dc, ds) - if err != nil { - return nil, nil, err - } - conn, err := pool.Acquire(ctx) - if err != nil { - return nil, nil, err - } - klog.V(5).Info("acquire pg conn from pool") - finish = func() { - if conn != nil { - conn.Release() - klog.V(5).Info("release pg conn to pool") - } - } - ops = append(ops, pgvector.WithConn(conn.Conn())) - } else { - ops = append(ops, pgvector.WithConnectionURL(vs.Spec.Endpoint.URL)) - } - if embedder != nil { - ops = append(ops, pgvector.WithEmbedder(embedder)) - } else { - llm, _ := openai.New() - embedder, _ = embeddings.NewEmbedder(llm) - } - ops = append(ops, pgvector.WithEmbedder(embedder)) - if collectionName != "" { - ops = append(ops, 
pgvector.WithCollectionName(collectionName)) - } else { - ops = append(ops, pgvector.WithCollectionName(vs.Spec.PGVector.CollectionName)) - } - v, err = pgvector.New(ctx, ops...) + v, finish, err = NewPGVectorStore(ctx, vs, c, dc, embedder, collectionName) case arcadiav1alpha1.VectorStoreTypeUnknown: fallthrough default: @@ -130,7 +62,7 @@ func NewVectorStore(ctx context.Context, vs *arcadiav1alpha1.VectorStore, embedd return v, finish, err } -func RemoveCollection(ctx context.Context, log logr.Logger, vs *arcadiav1alpha1.VectorStore, collectionName string) (err error) { +func RemoveCollection(ctx context.Context, log logr.Logger, vs *arcadiav1alpha1.VectorStore, collectionName string, c client.Client, dc dynamic.Interface) (err error) { switch vs.Spec.Type() { case arcadiav1alpha1.VectorStoreTypeChroma: ops := []chroma.Option{ @@ -151,19 +83,14 @@ func RemoveCollection(ctx context.Context, log logr.Logger, vs *arcadiav1alpha1. return err } case arcadiav1alpha1.VectorStoreTypePGVector: - ops := []pgvector.Option{ - pgvector.WithConnectionURL(vs.Spec.Endpoint.URL), - pgvector.WithPreDeleteCollection(vs.Spec.PGVector.PreDeleteCollection), - pgvector.WithCollectionTableName(vs.Spec.PGVector.CollectionTableName), - } - if collectionName != "" { - ops = append(ops, pgvector.WithCollectionName(collectionName)) - } else { - ops = append(ops, pgvector.WithCollectionName(vs.Spec.PGVector.CollectionName)) - } - v, err := pgvector.New(ctx, ops...) + v, finish, err := NewPGVectorStore(ctx, vs, c, dc, nil, collectionName) + defer func() { + if finish != nil { + finish() + } + }() if err != nil { - log.Error(err, "reconcile delete: init vector store error, may leave garbage data") + log.Error(err, "reconcile delete: init pgvector error, may leave garbage data") return err } if err = v.RemoveCollection(ctx); err != nil { @@ -178,3 +105,29 @@ func RemoveCollection(ctx context.Context, log logr.Logger, vs *arcadiav1alpha1. 
} return err } + +func AddDocuments(ctx context.Context, log logr.Logger, vs *arcadiav1alpha1.VectorStore, embedder embeddings.Embedder, collectionName string, c client.Client, dc dynamic.Interface, documents []lanchaingoschema.Document) (err error) { + s, finish, err := NewVectorStore(ctx, vs, embedder, collectionName, c, dc) + if err != nil { + return err + } + log.Info("handle file: add documents to embedder") + if store, ok := s.(*PGVectorStore); ok { + // now only pgvector support Row-level updates + log.Info("handle file: use pgvector, filter out exist documents") + if documents, err = store.RemoveExist(ctx, log, documents); err != nil { + return err + } + } + for i, doc := range documents { + log.V(5).Info(fmt.Sprintf("add doc to vectorstore, document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata)) + } + if _, err = s.AddDocuments(ctx, documents); err != nil { + return err + } + if finish != nil { + finish() + } + log.Info("handle file succeeded") + return nil +} diff --git a/tests/deploy-values.yaml b/tests/deploy-values.yaml index f591f58e1..27c1fb32a 100644 --- a/tests/deploy-values.yaml +++ b/tests/deploy-values.yaml @@ -1,3 +1,5 @@ +global: + defaultVectorStoreType: pgvector # @section controller is used as the core controller for arcadia # @param image Image to be used # @param imagePullPolcy ImagePullPolicy @@ -81,6 +83,7 @@ minio: # @section chromadb is used to deploy a chromadb instance chromadb: + enabled: true image: repository: kubeagi/chromadb chromadb: diff --git a/tests/example-test.sh b/tests/example-test.sh index 100af1d31..b6b10c271 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -158,6 +158,13 @@ function waitCRDStatusReady() { message=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].message') if [[ $readStatus == "True" ]]; then info $message + if [[ ${source} == "KnowledgeBase" ]]; then + fileStatus=$(kubectl get knowledgebase -n $namespace 
$name -o json | jq -r '.status.fileGroupDetail[0].fileDetails[0].phase') + if [[ $fileStatus != "Succeeded" ]]; then + kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails' + exit 1 + fi + fi break fi @@ -243,7 +250,7 @@ if [[ $datasourceType != "postgresql" ]]; then fi info "6. verify default vectorstore" -waitCRDStatusReady "VectorStore" "arcadia" "arcadia-vectorstore" +waitCRDStatusReady "VectorStore" "arcadia" "arcadia-pgvector-vectorstore" info "6.2 verify PGVector vectorstore" kubectl apply -f config/samples/arcadia_v1alpha1_vectorstore_pgvector.yaml waitCRDStatusReady "VectorStore" "arcadia" "pgvector-sample" @@ -281,7 +288,8 @@ info "7.4.2 create knowledgebase based on pgvector and wait it ready" kubectl apply -f config/samples/arcadia_v1alpha1_knowledgebase_pgvector.yaml waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector" -info "7.5 check chroma vectorstore has data" +info "7.5 check vectorstore has data" +info "7.5.1 check chroma vectorstore has data" kubectl port-forward -n arcadia svc/arcadia-chromadb 8000:8000 >/dev/null 2>&1 & chroma_pid=$! info "port-forward chroma in pid: $chroma_pid" @@ -295,6 +303,49 @@ else exit 1 fi +info "7.5.2 check pgvector vectorstore has data" +kubectl port-forward -n arcadia svc/arcadia-postgresql 5432:5432 >/dev/null 2>&1 & +postgres_pid=$! 
+info "port-forward postgres in pid: $postgres_pid" +sleep 3 +password=$(kubectl get secrets -n arcadia arcadia-postgresql -o json | jq -r '.data."postgres-password"' | base64 --decode) +if [[ $GITHUB_ACTIONS == "true" ]]; then + docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h localhost -c "select document from langchain_pg_embedding;" + pgdata=$(docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h localhost -c "select document from langchain_pg_embedding;") +else + docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h host.docker.internal -c "select document from langchain_pg_embedding;" + pgdata=$(docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h host.docker.internal -c "select document from langchain_pg_embedding;") +fi +if [[ -z $pgdata ]]; then + info "get no data in postgres" + exit 1 +fi + +info "7.6 update qa.csv to make sure it can be embedded" +echo "newquestion,newanswer" >>pkg/documentloaders/testdata/qa.csv +mc cp pkg/documentloaders/testdata/qa.csv arcadiatest/${bucket}/dataset/dataset-playground/v1/qa.csv +mc tag set arcadiatest/${bucket}/dataset/dataset-playground/v1/qa.csv "object_type=QA" +sleep 3 +kubectl annotate knowledgebase/knowledgebase-sample-pgvector -n arcadia "arcadia.kubeagi.k8s.com.cn/update-source-file-time=$(date)" +sleep 3 +waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample-pgvector" +if [[ $GITHUB_ACTIONS == "true" ]]; then + docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h localhost -c "select document from langchain_pg_embedding;" + pgdata=$(docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h localhost -c "select 
document from langchain_pg_embedding;") +else + docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h host.docker.internal -c "select document from langchain_pg_embedding;" + pgdata=$(docker run --net=host --entrypoint="" -e PGPASSWORD=$password kubeagi/postgresql:latest psql -U postgres -d arcadia -h host.docker.internal -c "select document from langchain_pg_embedding;") +fi +if [[ -z $pgdata ]]; then + info "get no data in postgres" + exit 1 +else + if [[ ! $pgdata =~ "newquestion" ]]; then + info "get no new data in postgres" + exit 1 + fi +fi + info "8 validate simple app can work normally" info "Prepare dependent LLM service" kubectl apply -f config/samples/app_shared_llm_service.yaml