diff --git a/api/base/v1alpha1/embedder.go b/api/base/v1alpha1/embedder.go index 44e39118e..72337de75 100644 --- a/api/base/v1alpha1/embedder.go +++ b/api/base/v1alpha1/embedder.go @@ -26,33 +26,39 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/dynamic" "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/pkg/utils" ) -func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client) (string, error) { +func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.Interface) (string, error) { if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil { return "", nil } - authSecret := &corev1.Secret{} - err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret) - if err != nil { + if err := utils.ValidClient(c, cli); err != nil { return "", err } - return string(authSecret.Data["apiKey"]), nil -} - -func (e Embedder) AuthAPIKeyByDynamicCli(ctx context.Context, cli dynamic.Interface) (string, error) { - if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil { - return "", nil - } authSecret := &corev1.Secret{} - obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}). - Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{}) - if err != nil { - return "", err - } - err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret) - if err != nil { - return "", err + if c != nil { + if err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret); err != nil { + return "", err + } + } else { + obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}). + Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{}) + if err != nil { + return "", err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret) + if err != nil { + return "", err + } } return string(authSecret.Data["apiKey"]), nil } + +type EmbeddingType string + +const ( + OpenAI EmbeddingType = "openai" + ZhiPuAI EmbeddingType = "zhipuai" +) diff --git a/api/base/v1alpha1/embedder_types.go b/api/base/v1alpha1/embedder_types.go index f017ec992..8c4a5f5d6 100644 --- a/api/base/v1alpha1/embedder_types.go +++ b/api/base/v1alpha1/embedder_types.go @@ -18,8 +18,6 @@ package v1alpha1 import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/kubeagi/arcadia/pkg/embeddings" ) // EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN! @@ -30,7 +28,7 @@ type EmbedderSpec struct { CommonSpec `json:",inline"` // ServiceType indicates the source type of embedding service - Type embeddings.EmbeddingType `json:"type"` + Type EmbeddingType `json:"type"` // Provider defines the provider info which provide this embedder service Provider `json:"provider,omitempty"` diff --git a/api/base/v1alpha1/worker.go b/api/base/v1alpha1/worker.go index d63ed1ef7..97da58118 100644 --- a/api/base/v1alpha1/worker.go +++ b/api/base/v1alpha1/worker.go @@ -24,7 +24,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/kubeagi/arcadia/pkg/embeddings" "github.com/kubeagi/arcadia/pkg/llms" ) @@ -121,7 +120,7 @@ func (worker Worker) BuildEmbedder() *Embedder { DisplayName: worker.Spec.Model.Name, Description: "Embedder created by Worker(OpenAI compatible)", }, - Type: embeddings.OpenAI, + Type: OpenAI, Provider: Provider{ Worker: &TypedObjectReference{ Kind: "Worker", diff --git a/controllers/embedder_controller.go b/controllers/embedder_controller.go index 8f35897ae..d32c78ca6 100644 --- a/controllers/embedder_controller.go +++ b/controllers/embedder_controller.go @@ -36,7 +36,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/predicate" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/embeddings" "github.com/kubeagi/arcadia/pkg/llms/openai" "github.com/kubeagi/arcadia/pkg/llms/zhipuai" ) @@ -156,20 +155,20 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l var msg string // Check Auth availability - apiKey, err := instance.AuthAPIKey(ctx, r.Client) + apiKey, err := instance.AuthAPIKey(ctx, r.Client, nil) if err != nil { return r.UpdateStatus(ctx, instance, nil, err) } switch instance.Spec.Type { - case embeddings.ZhiPuAI: + case arcadiav1alpha1.ZhiPuAI: embedClient := zhipuai.NewZhiPuAI(apiKey) res, err := embedClient.Validate() if err != nil { return r.UpdateStatus(ctx, instance, nil, err) } msg = res.String() - case embeddings.OpenAI: + case arcadiav1alpha1.OpenAI: embedClient := openai.NewOpenAI(apiKey, instance.Spec.Enpoint.URL) res, err := embedClient.Validate() if err != nil { diff --git a/controllers/knowledgebase_controller.go b/controllers/knowledgebase_controller.go index 38fca7281..42f304723 100644 --- a/controllers/knowledgebase_controller.go +++ b/controllers/knowledgebase_controller.go @@ -23,13 +23,13 @@ import ( "fmt" "io" "path/filepath" + "sync" "time" "github.com/go-logr/logr" "github.com/minio/minio-go/v7" "github.com/tmc/langchaingo/documentloaders" - langchainembeddings "github.com/tmc/langchaingo/embeddings" - "github.com/tmc/langchaingo/llms/openai" + "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/textsplitter" "github.com/tmc/langchaingo/vectorstores/chroma" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -46,8 +46,6 @@ import ( "github.com/kubeagi/arcadia/pkg/datasource" pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders" "github.com/kubeagi/arcadia/pkg/embeddings" - zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" - "github.com/kubeagi/arcadia/pkg/llms/zhipuai" "github.com/kubeagi/arcadia/pkg/utils" ) @@ -68,7 +66,9 @@ var ( // KnowledgeBaseReconciler reconciles a KnowledgeBase object type KnowledgeBaseReconciler struct { client.Client - Scheme *runtime.Scheme + Scheme *runtime.Scheme + mu sync.Mutex + HasHandledSuccessPath map[string]bool } //+kubebuilder:rbac:groups=arcadia.kubeagi.k8s.com.cn,resources=knowledgebases,verbs=get;list;watch;create;update;patch;delete @@ -253,7 +253,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo return errDataSourceNotReady } - system, err := config.GetSystemDatasource(ctx, r.Client) + system, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { return err } @@ -290,6 +290,12 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo errs := make([]error, 0) for _, path := range group.Paths { + r.mu.Lock() + hasHandled := r.HasHandledSuccessPath[r.hasHandledPathKey(kb, group, path)] + r.mu.Unlock() + if hasHandled { + continue + } fileDetail, ok := pathMap[path] if !ok { fileDetail = &arcadiav1alpha1.FileDetails{ @@ -361,6 +367,9 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo fileDetail.UpdateErr(err, arcadiav1alpha1.FileProcessPhaseFailed) continue } + r.mu.Lock() + r.HasHandledSuccessPath[r.hasHandledPathKey(kb, group, path)] = true + r.mu.Unlock() fileDetail.UpdateErr(nil, arcadiav1alpha1.FileProcessPhaseSucceeded) } return utilerrors.NewAggregate(errs) @@ -383,51 +392,9 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge if !store.Status.IsReady() { return errVectorStoreNotReady } - var em langchainembeddings.Embedder - switch embedder.Spec.Provider.GetType() { - case arcadiav1alpha1.ProviderType3rdParty: - switch embedder.Spec.Type { // nolint: gocritic - case embeddings.ZhiPuAI: - apiKey, err := embedder.AuthAPIKey(ctx, r.Client) - if err != nil { - return err - } - em, err = zhipuaiembeddings.NewZhiPuAI( - zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), - ) - if err != nil { - return err - } - } - case arcadiav1alpha1.ProviderTypeWorker: - gatway, err := config.GetGateway(ctx, r.Client) - if err != nil { - return err - } - if gatway == nil { - return fmt.Errorf("global config gateway not found") - } - refWorker := embedder.Spec.Worker - if refWorker == nil { - return fmt.Errorf("embedder.spec.worker not defined") - } - worker := &arcadiav1alpha1.Worker{} - if err := r.Client.Get(ctx, types.NamespacedName{Namespace: refWorker.GetNamespace(), Name: refWorker.Name}, worker); err != nil { - return err - } - refModel := worker.Spec.Model - if refModel == nil { - return fmt.Errorf("worker.spec.model not defined") - } - modelName := worker.MakeRegistrationModelName() - llm, err := openai.New(openai.WithModel(modelName), openai.WithBaseURL(gatway.APIServer), openai.WithToken("fake")) - if err != nil { - return err - } - em, err = langchainembeddings.NewEmbedder(llm) - if err != nil { - return err - } + em, err := embeddings.GetLangchainEmbedder(ctx, embedder, r.Client, nil) + if err != nil { + return err } data, err := io.ReadAll(file) // TODO Load large files in pieces to save memory // TODO Line or single line byte exceeds embedder limit @@ -435,17 +402,22 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge return err } dataReader := bytes.NewReader(data) + var documents []schema.Document var loader documentloaders.Loader switch filepath.Ext(fileName) { - case "txt": + case ".txt": loader = documentloaders.NewText(dataReader) - case "csv": + case ".csv": if v == arcadiav1alpha1.ObjectTypeQA { loader = pkgdocumentloaders.NewQACSV(dataReader, fileName, "q", "a") + documents, err = loader.Load(ctx) + if err != nil { + return err + } } else { loader = documentloaders.NewCSV(dataReader) } - case "html", "htm": + case ".html", ".htm": loader = documentloaders.NewHTML(dataReader) default: loader = documentloaders.NewText(dataReader) @@ -475,11 +447,15 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge // ) //} - documents, err := loader.LoadAndSplit(ctx, split) - if err != nil { - return err + if len(documents) == 0 { + documents, err = loader.LoadAndSplit(ctx, split) + if err != nil { + return err + } + } + for i, doc := range documents { + log.V(5).Info(fmt.Sprintf("document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata)) } - switch store.Spec.Type() { // nolint: gocritic case arcadiav1alpha1.VectorStoreTypeChroma: s, err := chroma.New( @@ -500,6 +476,13 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge } func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.Logger, kb *arcadiav1alpha1.KnowledgeBase) { + r.mu.Lock() + for _, fg := range kb.Spec.FileGroups { + for _, path := range fg.Paths { + delete(r.HasHandledSuccessPath, r.hasHandledPathKey(kb, fg, path)) + } + } + r.mu.Unlock() vectorStore := &arcadiav1alpha1.VectorStore{} if err := r.Get(ctx, types.NamespacedName{Name: kb.Spec.VectorStore.Name, Namespace: kb.Spec.VectorStore.GetNamespace()}, vectorStore); err != nil { log.Error(err, "reconcile delete: get vector store error, may leave garbage data") @@ -511,6 +494,7 @@ func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr. chroma.WithChromaURL(vectorStore.Spec.Enpoint.URL), chroma.WithDistanceFunction(vectorStore.Spec.Chroma.DistanceFunction), chroma.WithNameSpace(kb.VectorStoreCollectionName()), + chroma.WithOpenAiAPIKey("fake"), ) if err != nil { log.Error(err, "reconcile delete: init vector store error, may leave garbage data") @@ -520,3 +504,11 @@ func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr. } } } + +func (r *KnowledgeBaseReconciler) hasHandledPathKey(kb *arcadiav1alpha1.KnowledgeBase, filegroup arcadiav1alpha1.FileGroup, path string) string { + sourceName := "" + if filegroup.Source != nil { + sourceName = filegroup.Source.Name + } + return kb.Name + "/" + kb.Namespace + "/" + sourceName + "/" + path +} diff --git a/controllers/model_controller.go b/controllers/model_controller.go index 707986b1c..455b9422a 100644 --- a/controllers/model_controller.go +++ b/controllers/model_controller.go @@ -179,7 +179,7 @@ func (r *ModelReconciler) CheckModel(ctx context.Context, logger logr.Logger, in var ds datasource.Datasource var info any - system, err := config.GetSystemDatasource(ctx, r.Client) + system, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { return r.UpdateStatus(ctx, instance, err) } @@ -213,7 +213,7 @@ func (r *ModelReconciler) RemoveModel(ctx context.Context, logger logr.Logger, i var ds datasource.Datasource var info any - system, err := config.GetSystemDatasource(ctx, r.Client) + system, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { return r.UpdateStatus(ctx, instance, err) } diff --git a/controllers/namespace_controller.go b/controllers/namespace_controller.go index cc19d4ef0..34cd30208 100644 --- a/controllers/namespace_controller.go +++ b/controllers/namespace_controller.go @@ -123,7 +123,7 @@ func (r *NamespaceReconciler) SetupWithManager(mgr ctrl.Manager) error { } func (r *NamespaceReconciler) ossClient(ctx context.Context) (*datasource.OSS, error) { - systemDatasource, err := config.GetSystemDatasource(ctx, r.Client) + systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { klog.Errorf("get system datasource error %s", err) return nil, err diff --git a/controllers/versioneddataset_controller.go b/controllers/versioneddataset_controller.go index e8f7bc0bd..c22224886 100644 --- a/controllers/versioneddataset_controller.go +++ b/controllers/versioneddataset_controller.go @@ -212,7 +212,7 @@ func (r *VersionedDatasetReconciler) preUpdate(ctx context.Context, logger logr. func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) (bool, []v1alpha1.FileStatus, error) { // TODO: Currently, we think there is only one default minio environment, // so we get the minio client directly through the configuration. - systemDatasource, err := config.GetSystemDatasource(ctx, r.Client) + systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { logger.Error(err, "Failed to get system datasource") return false, nil, err @@ -232,7 +232,7 @@ func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger log } func (r *VersionedDatasetReconciler) removeBucketFiles(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) error { - systemDatasource, err := config.GetSystemDatasource(ctx, r.Client) + systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { logger.Error(err, "Failed to get system datasource") return err diff --git a/controllers/worker_controller.go b/controllers/worker_controller.go index b9a7db49b..881298aee 100644 --- a/controllers/worker_controller.go +++ b/controllers/worker_controller.go @@ -159,7 +159,7 @@ func (r *WorkerReconciler) Initialize(ctx context.Context, logger logr.Logger, i func (r *WorkerReconciler) reconcile(ctx context.Context, logger logr.Logger, worker *arcadiav1alpha1.Worker) error { // reconcile worker instance - system, err := config.GetSystemDatasource(ctx, r.Client) + system, err := config.GetSystemDatasource(ctx, r.Client, nil) if err != nil { return fmt.Errorf("failed to get system datasource with %w", err) } diff --git a/graphql-server/go-server/pkg/embedder/embedder.go b/graphql-server/go-server/pkg/embedder/embedder.go index c1f62ec8f..b14d2446d 100644 --- a/graphql-server/go-server/pkg/embedder/embedder.go +++ b/graphql-server/go-server/pkg/embedder/embedder.go @@ -31,7 +31,6 @@ import ( "github.com/kubeagi/arcadia/graphql-server/go-server/graph/generated" "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/common" graphqlutils "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/utils" - "github.com/kubeagi/arcadia/pkg/embeddings" "github.com/kubeagi/arcadia/pkg/utils" ) @@ -104,7 +103,7 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr URL: input.Endpointinput.URL, }, }, - Type: embeddings.EmbeddingType(servicetype), + Type: v1alpha1.EmbeddingType(servicetype), }, } diff --git a/main.go b/main.go index 267a82caf..76d242d84 100644 --- a/main.go +++ b/main.go @@ -211,8 +211,9 @@ func main() { os.Exit(1) } if err = (&controllers.KnowledgeBaseReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + HasHandledSuccessPath: make(map[string]bool, 0), }).SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "KnowledgeBase") os.Exit(1) diff --git a/pkg/application/app_run.go b/pkg/application/app_run.go index fa4d456e3..90bca9c31 100644 --- a/pkg/application/app_run.go +++ b/pkg/application/app_run.go @@ -24,6 +24,7 @@ import ( "reflect" "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" "k8s.io/utils/strings/slices" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" @@ -134,6 +135,7 @@ func (a *Application) Init(ctx context.Context, cli dynamic.Interface) (err erro a.StartingNodes = append(a.StartingNodes, current) } } + klog.Infof("init application success ending node: %s\n", a.EndingNode) return nil } diff --git a/pkg/application/chain/retrievalqachain.go b/pkg/application/chain/retrievalqachain.go index fd03f51a2..971f7e6f9 100644 --- a/pkg/application/chain/retrievalqachain.go +++ b/pkg/application/chain/retrievalqachain.go @@ -28,6 +28,7 @@ import ( "k8s.io/klog/v2" "github.com/kubeagi/arcadia/pkg/application/base" + appretriever "github.com/kubeagi/arcadia/pkg/application/retriever" ) type RetrievalQAChain struct { @@ -69,7 +70,13 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ dynamic.Interface, args ma } llmChain := chains.NewLLMChain(llm, prompt) - chain := chains.NewRetrievalQA(chains.NewStuffDocuments(llmChain), retriever) + var baseChain chains.Chain + if _, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok { + baseChain = appretriever.NewStuffDocuments(llmChain) + } else { + baseChain = chains.NewStuffDocuments(llmChain) + } + chain := chains.NewRetrievalQA(baseChain, retriever) l.RetrievalQA = chain args["query"] = args["question"] var out string diff --git a/pkg/application/retriever/knowledgebaseretriever.go b/pkg/application/retriever/knowledgebaseretriever.go index 0bc5f80ad..b6e50e6e4 100644 --- a/pkg/application/retriever/knowledgebaseretriever.go +++ b/pkg/application/retriever/knowledgebaseretriever.go @@ -20,7 +20,7 @@ import ( "context" "fmt" - langchainembeddings "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/chains" langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" "github.com/tmc/langchaingo/vectorstores/chroma" @@ -28,13 +28,12 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/application/base" "github.com/kubeagi/arcadia/pkg/embeddings" - zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" - "github.com/kubeagi/arcadia/pkg/llms/zhipuai" ) type KnowledgeBaseRetriever struct { @@ -82,21 +81,10 @@ func NewKnowledgeBaseRetriever(ctx context.Context, baseNode base.BaseNode, cli if err != nil { return nil, err } - var em langchainembeddings.Embedder - switch embedder.Spec.Type { // nolint: gocritic - case embeddings.ZhiPuAI: - apiKey, err := embedder.AuthAPIKeyByDynamicCli(ctx, cli) - if err != nil { - return nil, err - } - em, err = zhipuaiembeddings.NewZhiPuAI( - zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), - ) - if err != nil { - return nil, err - } + em, err := embeddings.GetLangchainEmbedder(ctx, embedder, nil, cli) + if err != nil { + return nil, err } - vectorStore := &v1alpha1.VectorStore{} obj, err = cli.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "vectorstores"}). Namespace(vectorStoreReq.GetNamespace()).Get(ctx, vectorStoreReq.Name, metav1.GetOptions{}) @@ -133,3 +121,61 @@ func (l *KnowledgeBaseRetriever) Run(ctx context.Context, _ dynamic.Interface, a args["retriever"] = l return args, nil } + +// KnowledgeBaseStuffDocuments is similar to chains.StuffDocuments but with new joinDocuments method +type KnowledgeBaseStuffDocuments struct { + chains.StuffDocuments +} + +var _ chains.Chain = KnowledgeBaseStuffDocuments{} + +func (c KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Document) string { + var text string + docLen := len(docs) + for k, doc := range docs { + answer := doc.Metadata["a"] + answerBytes, _ := answer.([]byte) + text += doc.PageContent + if len(answerBytes) != 0 { + text = text + "\na: " + string(answerBytes) + } + if k != docLen-1 { + text += c.Separator + } + } + klog.Infof("get related text: %s\n", text) + return text +} + +func NewStuffDocuments(llmChain *chains.LLMChain) KnowledgeBaseStuffDocuments { + return KnowledgeBaseStuffDocuments{ + StuffDocuments: chains.NewStuffDocuments(llmChain), + } +} + +func (c KnowledgeBaseStuffDocuments) Call(ctx context.Context, values map[string]any, options ...chains.ChainCallOption) (map[string]any, error) { + docs, ok := values[c.InputKey].([]langchaingoschema.Document) + if !ok { + return nil, fmt.Errorf("%w: %w", chains.ErrInvalidInputValues, chains.ErrInputValuesWrongType) + } + + inputValues := make(map[string]any) + for key, value := range values { + inputValues[key] = value + } + + inputValues[c.DocumentVariableName] = c.joinDocuments(docs) + return chains.Call(ctx, c.LLMChain, inputValues, options...) +} + +func (c KnowledgeBaseStuffDocuments) GetMemory() langchaingoschema.Memory { + return c.StuffDocuments.GetMemory() +} + +func (c KnowledgeBaseStuffDocuments) GetInputKeys() []string { + return c.StuffDocuments.GetInputKeys() +} + +func (c KnowledgeBaseStuffDocuments) GetOutputKeys() []string { + return c.StuffDocuments.GetOutputKeys() +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 54636eeb1..1323fb406 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -28,7 +28,6 @@ import ( "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/yaml" "k8s.io/client-go/dynamic" - "k8s.io/klog/v2" "k8s.io/utils/env" "sigs.k8s.io/controller-runtime/pkg/client" @@ -50,8 +49,8 @@ var ( ErrNoConfigStreamlit = fmt.Errorf("config Streamlit in comfigmap is not found") ) -func GetSystemDatasource(ctx context.Context, c client.Client) (*arcadiav1alpha1.Datasource, error) { - config, err := GetConfig(ctx, c) +func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { + config, err := GetConfig(ctx, c, cli) if err != nil { return nil, err } @@ -63,38 +62,26 @@ func GetSystemDatasource(ctx context.Context, c client.Client) (*arcadiav1alpha1 namespace = utils.GetCurrentNamespace() } source := &arcadiav1alpha1.Datasource{} - if err = c.Get(ctx, client.ObjectKey{Name: name, Namespace: namespace}, source); err != nil { - return nil, err - } - return source, err -} - -func GetSystemDatasourceDynamic(ctx context.Context, c dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { - config, err := GetConfigDynamic(ctx, c) - if err != nil { - return nil, err - } - name := config.SystemDatasource.Name - var namespace string - if config.SystemDatasource.Namespace != nil { - namespace = *config.SystemDatasource.Namespace + if c != nil { + if err = c.Get(ctx, client.ObjectKey{Name: name, Namespace: namespace}, source); err != nil { + return nil, err + } } else { - namespace = utils.GetCurrentNamespace() - } - source := &arcadiav1alpha1.Datasource{} - obj, err := c.Resource(schema.GroupVersionResource{Group: arcadiav1alpha1.GroupVersion.Group, Version: arcadiav1alpha1.GroupVersion.Version, Resource: "datasources"}). - Namespace(namespace).Get(ctx, name, v1.GetOptions{}) - if err != nil { - return nil, err + obj, err := cli.Resource(schema.GroupVersionResource{Group: arcadiav1alpha1.GroupVersion.Group, Version: arcadiav1alpha1.GroupVersion.Version, Resource: "datasources"}). + Namespace(namespace).Get(ctx, name, v1.GetOptions{}) + if err != nil { + return nil, err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.Object, source) + if err != nil { + return nil, err + } } - if err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.Object, source); err != nil { - return nil, err - } - return source, nil + return source, err } -func GetGateway(ctx context.Context, c client.Client) (*Gateway, error) { - config, err := GetConfig(ctx, c) +func GetGateway(ctx context.Context, c client.Client, cli dynamic.Interface) (*Gateway, error) { + config, err := GetConfig(ctx, c, cli) if err != nil { return nil, err } @@ -105,7 +92,7 @@ func GetGateway(ctx context.Context, c client.Client) (*Gateway, error) { } func GetMinIO(ctx context.Context, c dynamic.Interface) (*MinIO, error) { - datasource, err := GetSystemDatasourceDynamic(ctx, c) + datasource, err := GetSystemDatasource(ctx, nil, c) if err != nil { return nil, err } @@ -136,42 +123,30 @@ func GetMinIO(ctx context.Context, c dynamic.Interface) (*MinIO, error) { return &m, nil } -func GetConfigDynamic(ctx context.Context, c dynamic.Interface) (config *Config, err error) { - cmName := env.GetString(EnvConfigKey, EnvConfigDefaultValue) - if cmName == "" { - return nil, ErrNoConfigEnv - } - cmNamespace := utils.GetCurrentNamespace() - u, err := c.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "configmaps"}).Namespace(cmNamespace).Get(ctx, cmName, v1.GetOptions{}) - if err != nil { - klog.Errorln("failed to get configmap resource", err, cmNamespace, cmName) - return nil, ErrNoConfig - } - data, found, err := unstructured.NestedStringMap(u.Object, "data") - if err != nil || !found { - klog.Errorln("failed to get data from configmap", err) - return nil, ErrNoConfig - } - value, ok := data["config"] - if !ok || len(value) == 0 { - klog.Errorln("no config file from configmap", err) - return nil, ErrNoConfig - } - if err = yaml.Unmarshal([]byte(value), &config); err != nil { +func GetConfig(ctx context.Context, c client.Client, cli dynamic.Interface) (config *Config, err error) { + if err := utils.ValidClient(c, cli); err != nil { return nil, err } - return config, nil -} - -func GetConfig(ctx context.Context, c client.Client) (config *Config, err error) { cmName := env.GetString(EnvConfigKey, EnvConfigDefaultValue) if cmName == "" { return nil, ErrNoConfigEnv } cmNamespace := utils.GetCurrentNamespace() cm := &corev1.ConfigMap{} - if err = c.Get(ctx, client.ObjectKey{Name: cmName, Namespace: cmNamespace}, cm); err != nil { - return nil, err + if c != nil { + if err = c.Get(ctx, client.ObjectKey{Name: cmName, Namespace: cmNamespace}, cm); err != nil { + return nil, err + } + } else { + obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "configmaps"}). + Namespace(cmNamespace).Get(ctx, cmName, v1.GetOptions{}) + if err != nil { + return nil, err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), cm) + if err != nil { + return nil, err + } } value, ok := cm.Data["config"] if !ok || len(value) == 0 { @@ -184,7 +159,7 @@ func GetConfig(ctx context.Context, c client.Client) (config *Config, err error) } func GetVectorStore(ctx context.Context, c dynamic.Interface) (*arcadiav1alpha1.TypedObjectReference, error) { - config, err := GetConfigDynamic(ctx, c) + config, err := GetConfig(ctx, nil, c) if err != nil { return nil, err } @@ -196,7 +171,7 @@ func GetVectorStore(ctx context.Context, c dynamic.Interface) (*arcadiav1alpha1. // Get the configuration of streamlit tool func GetStreamlit(ctx context.Context, c client.Client) (*Streamlit, error) { - config, err := GetConfig(ctx, c) + config, err := GetConfig(ctx, c, nil) if err != nil { return nil, err } diff --git a/pkg/embeddings/embeddings.go b/pkg/embeddings/embeddings.go index 2a5fa25b8..77da5ce72 100644 --- a/pkg/embeddings/embeddings.go +++ b/pkg/embeddings/embeddings.go @@ -16,16 +16,80 @@ limitations under the License. package embeddings -import langchaingoembeddings "github.com/tmc/langchaingo/embeddings" +import ( + "context" + "fmt" -type EmbeddingType string + langchaingoembeddings "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/llms/openai" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/dynamic" + "sigs.k8s.io/controller-runtime/pkg/client" -const ( - OpenAI EmbeddingType = "openai" - ZhiPuAI EmbeddingType = "zhipuai" + "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/pkg/config" + zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai" + "github.com/kubeagi/arcadia/pkg/llms/zhipuai" + "github.com/kubeagi/arcadia/pkg/utils" ) -type Embedding interface { - Type() EmbeddingType - langchaingoembeddings.Embedder +func GetLangchainEmbedder(ctx context.Context, e *v1alpha1.Embedder, c client.Client, cli dynamic.Interface) (em langchaingoembeddings.Embedder, err error) { + if err := utils.ValidClient(c, cli); err != nil { + return nil, err + } + switch e.Spec.Provider.GetType() { + case v1alpha1.ProviderType3rdParty: + switch e.Spec.Type { // nolint: gocritic + case v1alpha1.ZhiPuAI: + apiKey, err := e.AuthAPIKey(ctx, c, cli) + if err != nil { + return nil, err + } + return zhipuaiembeddings.NewZhiPuAI( + zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)), + ) + } + case v1alpha1.ProviderTypeWorker: + gateway, err := config.GetGateway(ctx, c, cli) + if err != nil { + return nil, err + } + if gateway == nil { + return nil, fmt.Errorf("global config gateway not found") + } + workerRef := e.Spec.Worker + if workerRef == nil { + return nil, fmt.Errorf("embedder.spec.worker not defined") + } + worker := &v1alpha1.Worker{} + if c != nil { + if err := c.Get(ctx, types.NamespacedName{Namespace: workerRef.GetNamespace(), Name: workerRef.Name}, worker); err != nil { + return nil, err + } + } else { + obj, err := cli.Resource(schema.GroupVersionResource{Group: v1alpha1.Group, Version: v1alpha1.Version, Resource: "workers"}). + Namespace(workerRef.GetNamespace()).Get(ctx, workerRef.Name, metav1.GetOptions{}) + if err != nil { + return nil, err + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), worker) + if err != nil { + return nil, err + } + } + modelRef := worker.Spec.Model + if modelRef == nil { + return nil, fmt.Errorf("worker.spec.model not defined") + } + modelName := worker.MakeRegistrationModelName() + llm, err := openai.New(openai.WithModel(modelName), openai.WithBaseURL(gateway.APIServer), openai.WithToken("fake")) + if err != nil { + return nil, err + } + return langchaingoembeddings.NewEmbedder(llm) + } + return nil, fmt.Errorf("unknown provider type") } diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 740250260..109a12147 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -49,7 +49,7 @@ func NewScheduler(ctx context.Context, c client.Client, instance *v1alpha1.Versi // TODO: Currently, we think there is only one default minio environment, // so we get the minio client directly through the configuration. - systemDatasource, err := config.GetSystemDatasource(ctx1, c) + systemDatasource, err := config.GetSystemDatasource(ctx1, c, nil) if err != nil { klog.Errorf("generate new scheduler error %s", err) cancel() diff --git a/pkg/utils/structured.go b/pkg/utils/structured.go index b1de892b3..2c26b0ac7 100644 --- a/pkg/utils/structured.go +++ b/pkg/utils/structured.go @@ -22,6 +22,8 @@ import ( "reflect" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/client-go/dynamic" + "sigs.k8s.io/controller-runtime/pkg/client" ) // UnstructuredToStructured convert unstructed object to a structured targe(must be a pointer) @@ -45,3 +47,13 @@ func UnstructuredToStructured(unstructuredObj *unstructured.Unstructured, target return nil } + +func ValidClient(c client.Client, cli dynamic.Interface) error { + if c == nil && cli == nil { + return fmt.Errorf("both client.Client and dynamic.Interface cannot be nil") + } + if c != nil && cli != nil { + return fmt.Errorf(" client.Client and dynamic.Interface cannot be set at the same time") + } + return nil +} diff --git a/pkg/worker/runner.go b/pkg/worker/runner.go index 7a0b5501a..0c8fc3d68 100644 --- a/pkg/worker/runner.go +++ b/pkg/worker/runner.go @@ -64,7 +64,7 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1. if model == nil { return nil, errors.New("nil model") } - gw, err := config.GetGateway(ctx, runner.c) + gw, err := config.GetGateway(ctx, runner.c, nil) if err != nil { return nil, fmt.Errorf("failed to get arcadia config with %w", err) } @@ -105,7 +105,7 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp if model == nil { return nil, errors.New("nil model") } - gw, err := config.GetGateway(ctx, runner.c) + gw, err := config.GetGateway(ctx, runner.c, nil) if err != nil { return nil, fmt.Errorf("failed to get arcadia config with %w", err) } diff --git a/tests/example-test.sh b/tests/example-test.sh index 4ca641393..b852217e0 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -264,4 +264,8 @@ waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase" sleep 3 curl -XPOST http://127.0.0.1:8081/chat --data '{"query":"旷工最小计算单位为多少天?","response_mode":"blocking","conversion_id":"","app_name":"base-chat-with-knowledgebase", "app_namespace":"arcadia"}' | jq -e '.message' +info "10 show apiserver logs for debug" +kubectl logs --tail=100 -n arcadia -l app=arcadia-apiserver >/tmp/apiserver.log +cat /tmp/apiserver.log + info "all finished! ✅"