Skip to content

Commit

Permalink
fix: get more accurate embedder segmentation
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <[email protected]>
  • Loading branch information
Abirdcfly committed Dec 12, 2023
1 parent 555ff89 commit 5b82b5f
Show file tree
Hide file tree
Showing 20 changed files with 291 additions and 187 deletions.
44 changes: 25 additions & 19 deletions api/base/v1alpha1/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,39 @@ import (
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/dynamic"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/pkg/utils"
)

func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client) (string, error) {
func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.Interface) (string, error) {
if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil {
return "", nil
}
authSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret)
if err != nil {
if err := utils.ValidClient(c, cli); err != nil {
return "", err
}
return string(authSecret.Data["apiKey"]), nil
}

func (e Embedder) AuthAPIKeyByDynamicCli(ctx context.Context, cli dynamic.Interface) (string, error) {
if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil {
return "", nil
}
authSecret := &corev1.Secret{}
obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}).
Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{})
if err != nil {
return "", err
}
err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret)
if err != nil {
return "", err
if c != nil {
if err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret); err != nil {
return "", err
}
} else {
obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}).
Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{})
if err != nil {
return "", err
}
err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret)
if err != nil {
return "", err
}
}
return string(authSecret.Data["apiKey"]), nil
}

// EmbeddingType is the string identifier for the kind of embedding service
// an Embedder uses; EmbedderSpec.Type holds one of these values.
type EmbeddingType string

// Supported embedding service types.
const (
// OpenAI selects an OpenAI (or OpenAI-compatible) embedding service.
OpenAI EmbeddingType = "openai"
// ZhiPuAI selects the ZhiPuAI embedding service.
ZhiPuAI EmbeddingType = "zhipuai"
)
4 changes: 1 addition & 3 deletions api/base/v1alpha1/embedder_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package v1alpha1

import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
)

// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
Expand All @@ -30,7 +28,7 @@ type EmbedderSpec struct {
CommonSpec `json:",inline"`

// ServiceType indicates the source type of embedding service
Type embeddings.EmbeddingType `json:"type"`
Type EmbeddingType `json:"type"`

// Provider defines the provider info which provide this embedder service
Provider `json:"provider,omitempty"`
Expand Down
3 changes: 1 addition & 2 deletions api/base/v1alpha1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms"
)

Expand Down Expand Up @@ -121,7 +120,7 @@ func (worker Worker) BuildEmbedder() *Embedder {
DisplayName: worker.Spec.Model.Name,
Description: "Embedder created by Worker(OpenAI compatible)",
},
Type: embeddings.OpenAI,
Type: OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
Kind: "Worker",
Expand Down
7 changes: 3 additions & 4 deletions controllers/embedder_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms/openai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
)
Expand Down Expand Up @@ -156,20 +155,20 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l
var msg string

// Check Auth availability
apiKey, err := instance.AuthAPIKey(ctx, r.Client)
apiKey, err := instance.AuthAPIKey(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}

switch instance.Spec.Type {
case embeddings.ZhiPuAI:
case arcadiav1alpha1.ZhiPuAI:
embedClient := zhipuai.NewZhiPuAI(apiKey)
res, err := embedClient.Validate()
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
msg = res.String()
case embeddings.OpenAI:
case arcadiav1alpha1.OpenAI:
embedClient := openai.NewOpenAI(apiKey, instance.Spec.Enpoint.URL)
res, err := embedClient.Validate()
if err != nil {
Expand Down
108 changes: 50 additions & 58 deletions controllers/knowledgebase_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import (
"fmt"
"io"
"path/filepath"
"sync"
"time"

"github.com/go-logr/logr"
"github.com/minio/minio-go/v7"
"github.com/tmc/langchaingo/documentloaders"
langchainembeddings "github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/textsplitter"
"github.com/tmc/langchaingo/vectorstores/chroma"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand All @@ -46,8 +46,6 @@ import (
"github.com/kubeagi/arcadia/pkg/datasource"
pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders"
"github.com/kubeagi/arcadia/pkg/embeddings"
zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"github.com/kubeagi/arcadia/pkg/utils"
)

Expand All @@ -68,7 +66,9 @@ var (
// KnowledgeBaseReconciler reconciles a KnowledgeBase object
type KnowledgeBaseReconciler struct {
client.Client
Scheme *runtime.Scheme
Scheme *runtime.Scheme
mu sync.Mutex
HasHandledSuccessPath map[string]bool
}

//+kubebuilder:rbac:groups=arcadia.kubeagi.k8s.com.cn,resources=knowledgebases,verbs=get;list;watch;create;update;patch;delete
Expand Down Expand Up @@ -253,7 +253,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo
return errDataSourceNotReady
}

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -290,6 +290,12 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo

errs := make([]error, 0)
for _, path := range group.Paths {
r.mu.Lock()
hasHandled := r.HasHandledSuccessPath[r.hasHandledPathKey(kb, group, path)]
r.mu.Unlock()
if hasHandled {
continue
}
fileDetail, ok := pathMap[path]
if !ok {
fileDetail = &arcadiav1alpha1.FileDetails{
Expand Down Expand Up @@ -361,6 +367,9 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo
fileDetail.UpdateErr(err, arcadiav1alpha1.FileProcessPhaseFailed)
continue
}
r.mu.Lock()
r.HasHandledSuccessPath[r.hasHandledPathKey(kb, group, path)] = true
r.mu.Unlock()
fileDetail.UpdateErr(nil, arcadiav1alpha1.FileProcessPhaseSucceeded)
}
return utilerrors.NewAggregate(errs)
Expand All @@ -383,69 +392,32 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
if !store.Status.IsReady() {
return errVectorStoreNotReady
}
var em langchainembeddings.Embedder
switch embedder.Spec.Provider.GetType() {
case arcadiav1alpha1.ProviderType3rdParty:
switch embedder.Spec.Type { // nolint: gocritic
case embeddings.ZhiPuAI:
apiKey, err := embedder.AuthAPIKey(ctx, r.Client)
if err != nil {
return err
}
em, err = zhipuaiembeddings.NewZhiPuAI(
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
)
if err != nil {
return err
}
}
case arcadiav1alpha1.ProviderTypeWorker:
gatway, err := config.GetGateway(ctx, r.Client)
if err != nil {
return err
}
if gatway == nil {
return fmt.Errorf("global config gateway not found")
}
refWorker := embedder.Spec.Worker
if refWorker == nil {
return fmt.Errorf("embedder.spec.worker not defined")
}
worker := &arcadiav1alpha1.Worker{}
if err := r.Client.Get(ctx, types.NamespacedName{Namespace: refWorker.GetNamespace(), Name: refWorker.Name}, worker); err != nil {
return err
}
refModel := worker.Spec.Model
if refModel == nil {
return fmt.Errorf("worker.spec.model not defined")
}
modelName := worker.MakeRegistrationModelName()
llm, err := openai.New(openai.WithModel(modelName), openai.WithBaseURL(gatway.APIServer), openai.WithToken("fake"))
if err != nil {
return err
}
em, err = langchainembeddings.NewEmbedder(llm)
if err != nil {
return err
}
em, err := embeddings.GetLangchainEmbedder(ctx, embedder, r.Client, nil)
if err != nil {
return err
}
data, err := io.ReadAll(file) // TODO Load large files in pieces to save memory
// TODO Line or single line byte exceeds embedder limit
if err != nil {
return err
}
dataReader := bytes.NewReader(data)
var documents []schema.Document
var loader documentloaders.Loader
switch filepath.Ext(fileName) {
case "txt":
case ".txt":
loader = documentloaders.NewText(dataReader)
case "csv":
case ".csv":
if v == arcadiav1alpha1.ObjectTypeQA {
loader = pkgdocumentloaders.NewQACSV(dataReader, fileName, "q", "a")
documents, err = loader.Load(ctx)
if err != nil {
return err
}
} else {
loader = documentloaders.NewCSV(dataReader)
}
case "html", "htm":
case ".html", ".htm":
loader = documentloaders.NewHTML(dataReader)
default:
loader = documentloaders.NewText(dataReader)
Expand Down Expand Up @@ -475,11 +447,15 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
// )
//}

documents, err := loader.LoadAndSplit(ctx, split)
if err != nil {
return err
if len(documents) == 0 {
documents, err = loader.LoadAndSplit(ctx, split)
if err != nil {
return err
}
}
for i, doc := range documents {
log.Info(fmt.Sprintf("document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata))
}

switch store.Spec.Type() { // nolint: gocritic
case arcadiav1alpha1.VectorStoreTypeChroma:
s, err := chroma.New(
Expand All @@ -500,6 +476,13 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
}

func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) {
r.mu.Lock()
for _, fg := range kb.Spec.FileGroups {
for _, path := range fg.Paths {
delete(r.HasHandledSuccessPath, r.hasHandledPathKey(kb, fg, path))
}
}
r.mu.Unlock()
vectorStore := &arcadiav1alpha1.VectorStore{}
if err := r.Get(ctx, types.NamespacedName{Name: kb.Spec.VectorStore.Name, Namespace: kb.Spec.VectorStore.GetNamespace()}, vectorStore); err != nil {
log.Error(err, "reconcile delete: get vector store error, may leave garbage data")
Expand All @@ -511,6 +494,7 @@ func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.
chroma.WithChromaURL(vectorStore.Spec.Enpoint.URL),
chroma.WithDistanceFunction(vectorStore.Spec.Chroma.DistanceFunction),
chroma.WithNameSpace(kb.VectorStoreCollectionName()),
chroma.WithOpenAiAPIKey("fake"),
)
if err != nil {
log.Error(err, "reconcile delete: init vector store error, may leave garbage data")
Expand All @@ -520,3 +504,11 @@ func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.
}
}
}

// hasHandledPathKey builds the lookup key used with HasHandledSuccessPath for
// one file path of a knowledge base file group.
//
// The key has the form "<kb name>/<kb namespace>/<source name>/<path>". The
// source-name segment is left empty when the file group has no Source set.
func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.KnowledgeBase, filegroup arcadiav1alpha1.FileGroup, path string) string {
	key := kb.Name + "/" + kb.Namespace + "/"
	if src := filegroup.Source; src != nil {
		key += src.Name
	}
	return key + "/" + path
}
4 changes: 2 additions & 2 deletions controllers/model_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (r *ModelReconciler) CheckModel(ctx context.Context, logger logr.Logger, in
var ds datasource.Datasource
var info any

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, err)
}
Expand Down Expand Up @@ -213,7 +213,7 @@ func (r *ModelReconciler) RemoveModel(ctx context.Context, logger logr.Logger, i
var ds datasource.Datasource
var info any

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, err)
}
Expand Down
2 changes: 1 addition & 1 deletion controllers/namespace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (r *NamespaceReconciler) SetupWithManager(mgr ctrl.Manager) error {
}

func (r *NamespaceReconciler) ossClient(ctx context.Context) (*datasource.OSS, error) {
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
klog.Errorf("get system datasource error %s", err)
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions controllers/versioneddataset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (r *VersionedDatasetReconciler) preUpdate(ctx context.Context, logger logr.
func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) (bool, []v1alpha1.FileStatus, error) {
// TODO: Currently, we think there is only one default minio environment,
// so we get the minio client directly through the configuration.
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
logger.Error(err, "Failed to get system datasource")
return false, nil, err
Expand All @@ -232,7 +232,7 @@ func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger log
}

func (r *VersionedDatasetReconciler) removeBucketFiles(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) error {
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
logger.Error(err, "Failed to get system datasource")
return err
Expand Down
2 changes: 1 addition & 1 deletion controllers/worker_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (r *WorkerReconciler) Initialize(ctx context.Context, logger logr.Logger, i

func (r *WorkerReconciler) reconcile(ctx context.Context, logger logr.Logger, worker *arcadiav1alpha1.Worker) error {
// reconcile worker instance
system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return fmt.Errorf("failed to get system datasource with %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions graphql-server/go-server/pkg/embedder/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/kubeagi/arcadia/graphql-server/go-server/graph/generated"
"github.com/kubeagi/arcadia/graphql-server/go-server/pkg/common"
graphqlutils "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/utils"
"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/utils"
)

Expand Down Expand Up @@ -104,7 +103,7 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr
URL: input.Endpointinput.URL,
},
},
Type: embeddings.EmbeddingType(servicetype),
Type: v1alpha1.EmbeddingType(servicetype),
},
}

Expand Down
Loading

0 comments on commit 5b82b5f

Please sign in to comment.