Skip to content

Commit

Permalink
feat: knowledgebase support more granular update
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <[email protected]>
  • Loading branch information
Abirdcfly committed Jan 12, 2024
1 parent d828a3d commit 48f3401
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 111 deletions.
5 changes: 5 additions & 0 deletions api/base/v1alpha1/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const (
	// UpdateSourceFileAnnotationKey is the key of the update source file annotation.
	// Setting this annotation to any non-empty value on a KnowledgeBase triggers a
	// manual re-processing of its source files; the controller removes the
	// annotation again once the update has been initiated. The value is presumably
	// a timestamp (the key ends in "-time") — TODO confirm with callers that set it.
	UpdateSourceFileAnnotationKey = Group + "/update-source-file-time"
)

// VectorStoreCollectionName returns the vector store collection name backing
// this knowledge base, formed as "<namespace>_<name>".
func (kb *KnowledgeBase) VectorStoreCollectionName() string {
	// Namespace-qualify the collection so knowledge bases with the same name
	// in different namespaces never collide.
	collection := kb.Namespace + "_" + kb.Name
	return collection
}
Expand Down
53 changes: 29 additions & 24 deletions controllers/base/knowledgebase_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,27 @@ func (r *KnowledgeBaseReconciler) SetupWithManager(mgr ctrl.Manager) error {
}

func (r *KnowledgeBaseReconciler) reconcile(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) (*arcadiav1alpha1.KnowledgeBase, ctrl.Result, error) {
// Observe generation change
if kb.Status.ObservedGeneration != kb.Generation {
kb.Status.ObservedGeneration = kb.Generation
// Observe generation change or manual update
if kb.Status.ObservedGeneration != kb.Generation || kb.Annotations[arcadiav1alpha1.UpdateSourceFileAnnotationKey] != "" {
if kb.Status.ObservedGeneration != kb.Generation {
log.Info("Generation changed")
kb.Status.ObservedGeneration = kb.Generation
}
kb = r.setCondition(kb, kb.InitCondition())
if updateStatusErr := r.patchStatus(ctx, kb); updateStatusErr != nil {
log.Error(updateStatusErr, "unable to update status after generation update")
return kb, ctrl.Result{Requeue: true}, updateStatusErr
}
if kb.Annotations[arcadiav1alpha1.UpdateSourceFileAnnotationKey] != "" {
log.Info("Manual update")
kbNew := kb.DeepCopy()
delete(kbNew.Annotations, arcadiav1alpha1.UpdateSourceFileAnnotationKey)
err := r.Patch(ctx, kbNew, client.MergeFrom(kb))
if err != nil {
return kb, ctrl.Result{Requeue: true}, err
}
}
r.cleanupHasHandledSuccessPath(kb)
}

if kb.Status.IsReady() {
Expand Down Expand Up @@ -461,35 +474,17 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
for i, doc := range documents {
log.V(5).Info(fmt.Sprintf("document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata))
}
s, finish, err := vectorstore.NewVectorStore(ctx, store, em, kb.VectorStoreCollectionName(), r.Client, nil)
if err != nil {
return err
}
log.Info("handle file: add documents to embedder")
if _, err = s.AddDocuments(ctx, documents); err != nil {
return err
}
if finish != nil {
finish()
}
log.Info("handle file succeeded")
return nil
return vectorstore.AddDocuments(ctx, log, store, em, kb.VectorStoreCollectionName(), r.Client, nil, documents)
}

// reconcileDelete cleans up after a KnowledgeBase that is being deleted: it
// drops the in-memory bookkeeping of successfully handled file paths and
// removes the backing vector store collection.
//
// Note: the rendered diff contained both the pre-change inline cleanup loop
// and its replacement, plus two conflicting RemoveCollection calls; only the
// post-change version is kept here so the function compiles.
func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) {
	// Forget which paths were already processed so a re-created knowledge
	// base with the same name starts from a clean slate.
	r.cleanupHasHandledSuccessPath(kb)
	vectorStore := &arcadiav1alpha1.VectorStore{}
	if err := r.Get(ctx, types.NamespacedName{Name: kb.Spec.VectorStore.Name, Namespace: kb.Spec.VectorStore.GetNamespace(kb.GetNamespace())}, vectorStore); err != nil {
		// Best effort: without the VectorStore object we cannot reach the
		// backing store, so the collection may be left behind.
		log.Error(err, "reconcile delete: get vector store error, may leave garbage data")
		return
	}
	// Removal failures are intentionally ignored; deletion must not block.
	_ = vectorstore.RemoveCollection(ctx, log, vectorStore, kb.VectorStoreCollectionName(), r.Client, nil)
}

func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.KnowledgeBase, filegroup arcadiav1alpha1.FileGroup, path string) string {
Expand All @@ -499,3 +494,13 @@ func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.Knowledg
}
return kb.Name + "/" + kb.Namespace + "/" + sourceName + "/" + path
}

// cleanupHasHandledSuccessPath removes every path of the given KnowledgeBase's
// file groups from the HasHandledSuccessPath bookkeeping map, so those files
// will be (re)processed on the next reconcile.
func (r *KnowledgeBaseReconciler) cleanupHasHandledSuccessPath(kb *arcadiav1alpha1.KnowledgeBase) {
	r.mu.Lock()
	// defer guarantees the mutex is released even if hasHandledPathKey panics,
	// avoiding a permanently held lock (the original unlocked manually).
	defer r.mu.Unlock()
	for _, fg := range kb.Spec.FileGroups {
		for _, path := range fg.Paths {
			delete(r.HasHandledSuccessPath, r.hasHandledPathKey(kb, fg, path))
		}
	}
}
9 changes: 7 additions & 2 deletions deploy/charts/arcadia/templates/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
{{- if .Values.postgresql.enabled }}
apiVersion: v1
data:
config: |
Expand All @@ -25,7 +24,12 @@ data:
vectorStore:
apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: VectorStore
{{- if and (.Values.chromadb.enabled) (eq .Values.global.defaultVectorStoreType "chroma") }}
name: '{{ .Release.Name }}-vectorstore'
{{- end }}
{{- if and (.Values.postgresql.enabled) (eq .Values.global.defaultVectorStoreType "pgvector") }}
name: '{{ .Release.Name }}-pgvector-vectorstore'
{{- end }}
namespace: '{{ .Release.Namespace }}'

#streamlit:
Expand All @@ -36,16 +40,17 @@ data:
dataprocess: |
llm:
qa_retry_count: {{ .Values.dataprocess.config.llm.qa_retry_count }}
{{- if .Values.postgresql.enabled }}
postgresql:
host: {{ .Release.Name }}-postgresql.{{ .Release.Namespace }}.svc.cluster.local
port: {{ .Values.postgresql.containerPorts.postgresql }}
user: {{ .Values.postgresql.global.postgresql.auth.username }}
password: {{ .Values.postgresql.global.postgresql.auth.password }}
database: {{ .Values.postgresql.global.postgresql.auth.database }}
{{- end }}
kind: ConfigMap
metadata:
labels:
control-plane: {{ .Release.Name }}-arcadia
name: {{ .Release.Name }}-config
namespace: {{ .Release.Namespace }}
{{- end }}
2 changes: 2 additions & 0 deletions deploy/charts/arcadia/templates/post-vectorstore.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{{- if .Values.chromadb.enabled }}
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: VectorStore
metadata:
Expand All @@ -13,6 +14,7 @@ spec:
url: 'http://{{ .Release.Name }}-chromadb.{{ .Release.Namespace }}.svc.cluster.local:{{ .Values.chromadb.chromadb.serverHttpPort }}'
chroma:
distanceFunction: cosine
{{- end }}

{{- if .Values.postgresql.enabled }}
---
Expand Down
4 changes: 4 additions & 0 deletions deploy/charts/arcadia/values.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
global:
oss:
bucket: &default-oss-bucket "arcadia"
## @param global.defaultVectorStoreType Defines the default vector database type; currently `chroma` and `pgvector` are available.
## When the option is `chroma`, `chromadb.enabled` must also be `true` for it to work.
## When the option is `pgvector`, `postgresql.enabled` must also be `true` for it to work.
defaultVectorStoreType: pgvector

# @section controller is used as the core controller for arcadia
# @param image Image to be used
Expand Down
176 changes: 176 additions & 0 deletions pkg/vectorstore/pgvector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
Copyright 2024 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package vectorstore

import (
"context"
"fmt"
"reflect"
"strings"

"github.com/jackc/pgx/v5"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/openai"
lanchaingoschema "github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/dynamic"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/datasource"
"github.com/kubeagi/arcadia/pkg/utils"
)

// Compile-time assertion that *PGVectorStore satisfies the langchaingo
// vectorstores.VectorStore interface.
var _ vectorstores.VectorStore = (*PGVectorStore)(nil)

// PGVectorStore bundles a langchaingo pgvector.Store with the raw pgx
// connection it runs on and the (effective) arcadia PGVector configuration,
// so callers can both use the vector-store API and run direct SQL
// (e.g. RemoveExist) against the same tables.
type PGVectorStore struct {
	*pgx.Conn
	pgvector.Store
	*arcadiav1alpha1.PGVector
}

// NewPGVectorStore builds a PGVectorStore for the given VectorStore resource.
//
// The connection is taken from the referenced Datasource's PostgreSQL pool when
// DataSourceRef is set (finish releases it back to the pool), otherwise a direct
// pgx connection to vs.Spec.Endpoint.URL is opened (finish is nil in that case).
// When embedder is nil, a default OpenAI embedder is constructed. collectionName
// overrides the collection name from the spec when non-empty.
//
// Fixes over the previous version: WithEmbedder is appended exactly once (it was
// appended twice when embedder was non-nil); errors constructing the fallback
// embedder are no longer silently ignored; the pooled connection is released
// instead of leaked when pgvector.New fails.
func NewPGVectorStore(ctx context.Context, vs *arcadiav1alpha1.VectorStore, c client.Client, dc dynamic.Interface, embedder embeddings.Embedder, collectionName string) (v *PGVectorStore, finish func(), err error) {
	v = &PGVectorStore{PGVector: vs.Spec.PGVector}
	ops := []pgvector.Option{
		pgvector.WithPreDeleteCollection(vs.Spec.PGVector.PreDeleteCollection),
	}
	if vs.Spec.PGVector.CollectionTableName != "" {
		ops = append(ops, pgvector.WithCollectionTableName(vs.Spec.PGVector.CollectionTableName))
	} else {
		// Record the effective table name so later raw SQL (RemoveExist) targets it.
		v.PGVector.CollectionTableName = pgvector.DefaultCollectionStoreTableName
	}
	if vs.Spec.PGVector.EmbeddingTableName != "" {
		ops = append(ops, pgvector.WithEmbeddingTableName(vs.Spec.PGVector.EmbeddingTableName))
	} else {
		v.PGVector.EmbeddingTableName = pgvector.DefaultEmbeddingStoreTableName
	}
	if ref := vs.Spec.PGVector.DataSourceRef; ref != nil {
		if err := utils.ValidateClient(c, dc); err != nil {
			return nil, nil, err
		}
		ds := &arcadiav1alpha1.Datasource{}
		if c != nil {
			if err := c.Get(ctx, types.NamespacedName{Name: ref.Name, Namespace: ref.GetNamespace(vs.GetNamespace())}, ds); err != nil {
				return nil, nil, err
			}
		} else {
			obj, err := dc.Resource(schema.GroupVersionResource{Group: "arcadia.kubeagi.k8s.com.cn", Version: "v1alpha1", Resource: "datasources"}).
				Namespace(ref.GetNamespace(vs.GetNamespace())).Get(ctx, ref.Name, metav1.GetOptions{})
			if err != nil {
				return nil, nil, err
			}
			if err := runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), ds); err != nil {
				return nil, nil, err
			}
		}
		vs.Spec.Endpoint = ds.Spec.Endpoint.DeepCopy()
		pool, err := datasource.GetPostgreSQLPool(ctx, c, dc, ds)
		if err != nil {
			return nil, nil, err
		}
		conn, err := pool.Acquire(ctx)
		if err != nil {
			return nil, nil, err
		}
		klog.V(5).Info("acquire pg conn from pool")
		finish = func() {
			if conn != nil {
				conn.Release()
				klog.V(5).Info("release pg conn to pool")
			}
		}
		v.Conn = conn.Conn()
		ops = append(ops, pgvector.WithConn(v.Conn))
	} else {
		conn, err := pgx.Connect(ctx, vs.Spec.Endpoint.URL)
		if err != nil {
			return nil, nil, err
		}
		v.Conn = conn
		ops = append(ops, pgvector.WithConn(conn))
	}
	if embedder == nil {
		// Fall back to the default OpenAI embedder; surface construction
		// errors instead of continuing with a nil embedder.
		llm, err := openai.New()
		if err != nil {
			if finish != nil {
				finish()
			}
			return nil, nil, err
		}
		embedder, err = embeddings.NewEmbedder(llm)
		if err != nil {
			if finish != nil {
				finish()
			}
			return nil, nil, err
		}
	}
	ops = append(ops, pgvector.WithEmbedder(embedder))
	if collectionName != "" {
		ops = append(ops, pgvector.WithCollectionName(collectionName))
		v.PGVector.CollectionName = collectionName
	} else {
		ops = append(ops, pgvector.WithCollectionName(vs.Spec.PGVector.CollectionName))
	}
	store, err := pgvector.New(ctx, ops...)
	if err != nil {
		// Release the pooled connection (if any) so it is not leaked on failure.
		if finish != nil {
			finish()
		}
		return nil, nil, err
	}
	v.Store = store
	return v, finish, nil
}

// RemoveExist filters out documents that already exist in pgvector with
// identical content and metadata, returning only those that still need to be
// embedded and stored.
// Note: it is currently assumed that the embedder of a knowledge base is constant that means the result of embedding a fixed document is fixed,
// disregarding the case where the embedder changes (and if it does, a lot of processing will need to be done in many places, not just here)
func (s *PGVectorStore) RemoveExist(ctx context.Context, document []lanchaingoschema.Document) (doc []lanchaingoschema.Document, err error) {
	if len(document) == 0 {
		return nil, nil
	}
	// Get collection_uuid from the collection table; if the collection does
	// not exist yet, none of the documents can already be stored.
	collectionUUID := ""
	sql := fmt.Sprintf(`SELECT uuid FROM %s WHERE name = $1 ORDER BY name limit 1`, s.PGVector.CollectionTableName)
	err = s.Conn.QueryRow(ctx, sql, s.PGVector.CollectionName).Scan(&collectionUUID)
	if err != nil && err != pgx.ErrNoRows {
		// Previously pgx.ErrNoRows was returned to the caller as an error.
		return nil, err
	}
	if collectionUUID == "" {
		return document, nil
	}
	// Bind document contents as query parameters ($2..$n). The previous
	// version spliced PageContent directly into the SQL string, which both
	// broke on embedded quotes and allowed SQL injection.
	placeholders := make([]string, 0, len(document))
	args := make([]any, 0, len(document)+1)
	args = append(args, collectionUUID)
	for i, d := range document {
		placeholders = append(placeholders, fmt.Sprintf("$%d", i+2))
		args = append(args, d.PageContent)
	}
	sql = fmt.Sprintf(`SELECT document, cmetadata FROM %s WHERE collection_id = $1 AND document IN (%s)`, s.PGVector.EmbeddingTableName, strings.Join(placeholders, ", "))
	rows, err := s.Conn.Query(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	res := make(map[string]lanchaingoschema.Document, len(document))
	for rows.Next() {
		d := lanchaingoschema.Document{}
		if err := rows.Scan(&d.PageContent, &d.Metadata); err != nil {
			return nil, err
		}
		res[d.PageContent] = d
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(res) == 0 {
		return document, nil
	}
	// Keep documents that are missing or whose metadata differs. The old
	// len(res)==len(document) shortcut skipped the metadata comparison and
	// could wrongly drop documents whose metadata had changed.
	doc = make([]lanchaingoschema.Document, 0, len(document))
	for _, d := range document {
		has, ok := res[d.PageContent]
		if !ok || !reflect.DeepEqual(has.Metadata, d.Metadata) {
			doc = append(doc, d)
		}
	}
	return doc, nil
}
Loading

0 comments on commit 48f3401

Please sign in to comment.