diff --git a/api/app-node/common_type.go b/api/app-node/common_type.go index 18493f19b..06bc4bedc 100644 --- a/api/app-node/common_type.go +++ b/api/app-node/common_type.go @@ -25,8 +25,15 @@ import ( const ( InputLengthAnnotationKey = v1alpha1.Group + `/input-rules` OutputLengthAnnotationKey = v1alpha1.Group + `/output-rules` + + // ConversationKnowledgebaseName is the placeholder name of the conversation knowledgebase + ConversationKnowledgebaseName = "conversation_knowledgebase_" ) +func IsPlaceholderConversationKnowledgebase(name string) bool { + return name == ConversationKnowledgebaseName +} + type Ref struct { Kind string `json:"kind,omitempty"` Group string `json:"group,omitempty"` diff --git a/api/app-node/retriever/v1alpha1/mergerretriever_types.go b/api/app-node/retriever/v1alpha1/mergerretriever_types.go new file mode 100644 index 000000000..cb8991ea9 --- /dev/null +++ b/api/app-node/retriever/v1alpha1/mergerretriever_types.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + node "github.com/kubeagi/arcadia/api/app-node" + "github.com/kubeagi/arcadia/api/base/v1alpha1" +) + +// MergerRetrieverSpec defines the desired state of MergerRetriever +type MergerRetrieverSpec struct { + v1alpha1.CommonSpec `json:",inline"` +} + +// MergerRetrieverStatus defines the observed state of MergerRetriever +type MergerRetrieverStatus struct { + // ObservedGeneration is the last observed generation. + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ConditionedStatus is the current status + v1alpha1.ConditionedStatus `json:",inline"` +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// MergerRetriever is the Schema for the MergerRetriever API +type MergerRetriever struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec MergerRetrieverSpec `json:"spec,omitempty"` + Status MergerRetrieverStatus `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// MergerRetrieverList contains a list of MergerRetriever +type MergerRetrieverList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []MergerRetriever `json:"items"` +} + +func init() { + SchemeBuilder.Register(&MergerRetriever{}, &MergerRetrieverList{}) +} + +var _ node.Node = (*MergerRetriever)(nil) + +func (c *MergerRetriever) SetRef() { + annotations := node.SetRefAnnotations(c.GetAnnotations(), []node.Ref{node.RetrieverRef.Len(1)}, []node.Ref{node.RetrievalQAChainRef.Len(1)}) + if c.GetAnnotations() == nil { + c.SetAnnotations(annotations) + } + for k, v := range annotations { + c.Annotations[k] = v + } +} diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go index 648d0d494..be0faddb5 100644 --- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go @@ -138,6 +138,97 @@ func (in *KnowledgeBaseRetrieverStatus) DeepCopy() *KnowledgeBaseRetrieverStatus return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetriever) DeepCopyInto(out *MergerRetriever) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetriever. +func (in *MergerRetriever) DeepCopy() *MergerRetriever { + if in == nil { + return nil + } + out := new(MergerRetriever) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetriever) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverList) DeepCopyInto(out *MergerRetrieverList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]MergerRetriever, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverList. +func (in *MergerRetrieverList) DeepCopy() *MergerRetrieverList { + if in == nil { + return nil + } + out := new(MergerRetrieverList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *MergerRetrieverList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverSpec) DeepCopyInto(out *MergerRetrieverSpec) { + *out = *in + out.CommonSpec = in.CommonSpec +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverSpec. +func (in *MergerRetrieverSpec) DeepCopy() *MergerRetrieverSpec { + if in == nil { + return nil + } + out := new(MergerRetrieverSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *MergerRetrieverStatus) DeepCopyInto(out *MergerRetrieverStatus) { + *out = *in + in.ConditionedStatus.DeepCopyInto(&out.ConditionedStatus) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MergerRetrieverStatus. +func (in *MergerRetrieverStatus) DeepCopy() *MergerRetrieverStatus { + if in == nil { + return nil + } + out := new(MergerRetrieverStatus) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *MultiQueryRetriever) DeepCopyInto(out *MultiQueryRetriever) { *out = *in diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 8df013ed8..a240083cb 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -87,6 +87,7 @@ type ComplexityRoot struct { EnableRerank func(childComplexity int) int EnableUploadFile func(childComplexity int) int Knowledgebase func(childComplexity int) int + Knowledgebases func(childComplexity int) int Llm func(childComplexity int) int MaxLength func(childComplexity int) int MaxTokens func(childComplexity int) int @@ -1096,6 +1097,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Application.Knowledgebase(childComplexity), true + case "Application.knowledgebases": + if e.complexity.Application.Knowledgebases == nil { + break + } + + return e.complexity.Application.Knowledgebases(childComplexity), true + case "Application.llm": if e.complexity.Application.Llm == nil { break @@ -5172,10 +5180,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -5496,7 +5509,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 @@ -9995,6 +10013,47 @@ func (ec *executionContext) fieldContext_Application_conversionWindowSize(ctx co return fc, nil } +func (ec *executionContext) _Application_knowledgebases(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Application_knowledgebases(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Knowledgebases, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]*string) + fc.Result = res + return ec.marshalOString2ᚕᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Application_knowledgebases(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Application", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Application_knowledgebase(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Application_knowledgebase(ctx, field) if err != nil { @@ -11686,6 +11745,8 @@ func (ec *executionContext) fieldContext_ApplicationMutation_updateApplicationCo return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -11795,6 +11856,8 @@ func (ec *executionContext) fieldContext_ApplicationQuery_getApplication(ctx con return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -28836,6 +28899,8 @@ func (ec *executionContext) fieldContext_RAG_application(ctx context.Context, fi return ec.fieldContext_Application_maxTokens(ctx, field) case "conversionWindowSize": return ec.fieldContext_Application_conversionWindowSize(ctx, field) + case "knowledgebases": + return ec.fieldContext_Application_knowledgebases(ctx, field) case "knowledgebase": return ec.fieldContext_Application_knowledgebase(ctx, field) case "scoreThreshold": @@ -39168,7 +39233,7 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} + fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "knowledgebases", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "systemPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools", "enableRerank", "rerankModel", "enableMultiQuery", "chatTimeout", "enableUploadFile", "chunkSize", "chunkOverlap", "batchSize"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -39245,6 +39310,13 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte return it, err } it.Knowledgebase = data + case "knowledgebases": + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("knowledgebases")) + data, err := ec.unmarshalOString2ᚕᚖstring(ctx, v) + if err != nil { + return it, err + } + it.Knowledgebases = data case "scoreThreshold": ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scoreThreshold")) data, err := ec.unmarshalOFloat2ᚖfloat64(ctx, v) @@ -40566,6 +40638,8 @@ func (ec *executionContext) _Application(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._Application_maxTokens(ctx, field, obj) case "conversionWindowSize": out.Values[i] = ec._Application_conversionWindowSize(ctx, field, obj) + case "knowledgebases": + out.Values[i] = ec._Application_knowledgebases(ctx, field, obj) case "knowledgebase": out.Values[i] = ec._Application_knowledgebase(ctx, field, obj) case "scoreThreshold": diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 9aa286ef1..1c3c2e8ca 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -54,6 +54,8 @@ type Application struct { MaxTokens *int `json:"maxTokens,omitempty"` // conversionWindowSize 对话轮次 ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` // scoreThreshold 最终返回结果的最低相似度 @@ -1707,6 +1709,8 @@ type UpdateApplicationConfigInput struct { ConversionWindowSize *int `json:"conversionWindowSize,omitempty"` // knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 Knowledgebase *string `json:"knowledgebase,omitempty"` + // knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + Knowledgebases []*string `json:"knowledgebases,omitempty"` // scoreThreshold 最终返回结果的最低相似度 ScoreThreshold *float64 `json:"scoreThreshold,omitempty"` // numDocuments 最终返回结果的引用上限 diff --git a/apiserver/graph/schema/application.gql b/apiserver/graph/schema/application.gql index 82804af47..ee68e3d33 100644 --- a/apiserver/graph/schema/application.gql +++ b/apiserver/graph/schema/application.gql @@ -78,6 +78,7 @@ mutation updateApplicationConfig($input: UpdateApplicationConfigInput!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn @@ -131,6 +132,7 @@ query getApplication($name: String!, $namespace: String!){ maxTokens conversionWindowSize knowledgebase + knowledgebases scoreThreshold numDocuments docNullReturn diff --git a/apiserver/graph/schema/application.graphqls b/apiserver/graph/schema/application.graphqls index f6c04afe7..4637a3e31 100644 --- a/apiserver/graph/schema/application.graphqls +++ b/apiserver/graph/schema/application.graphqls @@ -59,10 +59,15 @@ type Application { """ conversionWindowSize: Int + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] + """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") """ scoreThreshold 最终返回结果的最低相似度 @@ -383,7 +388,12 @@ input UpdateApplicationConfigInput { """ knowledgebase 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,目前一个应用只支持0或1个知识库 """ - knowledgebase: String + knowledgebase: String @deprecated(reason: "Use knowledgebases") + + """ + knowledgebases 指当前知识库应用使用的知识库,即 Kind 为 KnowledgeBase 的 CR 的名称,支持选择零个或一个或多个 + """ + knowledgebases: [String] """ scoreThreshold 最终返回结果的最低相似度 diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index 4fe3bca10..cbadaebf0 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -33,6 +33,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + appnode "github.com/kubeagi/arcadia/api/app-node" apiagent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apidocumentloader "github.com/kubeagi/arcadia/api/app-node/documentloader/v1alpha1" @@ -65,7 +66,7 @@ func addDefaultValue(gApp *generated.Application, app *v1alpha1.Application) { if len(app.Spec.Nodes) > 0 { return } - gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") + // gApp.DocNullReturn = pointer.String("未找到您询问的内容,请详细描述您的问题") gApp.NumDocuments = pointer.Int(5) gApp.ScoreThreshold = pointer.Float64(0.3) gApp.Temperature = pointer.Float64(0.7) @@ -136,7 +137,7 @@ func cr2app(prompt *apiprompt.Prompt, chainConfig *apichain.CommonChainConfig, r case "llm": gApp.Llm = node.Ref.Name case "knowledgebase": - gApp.Knowledgebase = pointer.String(node.Ref.Name) + gApp.Knowledgebases = append(gApp.Knowledgebases, pointer.String(node.Ref.Name)) } } if retriever != nil { @@ -429,6 +430,9 @@ func ListApplicationMeatadatas(ctx context.Context, c client.Client, input gener } func UpdateApplicationConfig(ctx context.Context, c client.Client, input generated.UpdateApplicationConfigInput) (*generated.Application, error) { + input.Knowledgebases = ConvertKnowledgebase2Knowledgebases(input.Knowledgebase, input.Knowledgebases) // nolint:staticcheck + hasKnowledgebase := utils.HasValues(input.Knowledgebases) || pointer.BoolDeref(input.EnableUploadFile, false) // input has knowledgebases or enable upload file(may have conversation knowledgebase) + // check tool name not duplicated if len(input.Tools) != 0 { key := make(map[string]bool, len(input.Tools)) for _, tool := range input.Tools { @@ -438,6 +442,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat key[tool.Name] = true } } + key := types.NamespacedName{Namespace: input.Namespace, Name: input.Name} // get application cr, if not exist, return error app := &v1alpha1.Application{} @@ -510,12 +515,43 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } } + // create or update placeholder conversation knowledgebase, if enable upload file + var conversationKnowledgebase *v1alpha1.KnowledgeBase + if pointer.BoolDeref(input.EnableUploadFile, false) { + conversationKnowledgebase = &v1alpha1.KnowledgeBase{ + ObjectMeta: metav1.ObjectMeta{ + Name: appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + Spec: v1alpha1.KnowledgeBaseSpec{ + CommonSpec: v1alpha1.CommonSpec{ + DisplayName: "conversation knowledgebase placeholder", + Description: "conversation knowledgebase placeholder", + }, + Type: v1alpha1.KnowledgeBaseTypeConversation, + }, + } + if _, err := controllerutil.CreateOrUpdate(ctx, c, conversationKnowledgebase, func() error { + return nil + }); err != nil { + return nil, err + } + } else { + conversationKnowledgebase = &v1alpha1.KnowledgeBase{ + ObjectMeta: metav1.ObjectMeta{ + Name: appnode.ConversationKnowledgebaseName, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, conversationKnowledgebase) + } + // create or update chain var ( chainConfig *apichain.CommonChainConfig retriever *apiretriever.CommonRetrieverConfig ) - if utils.HasValue(input.Knowledgebase) { + if hasKnowledgebase { qachain := &apichain.RetrievalQAChain{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -547,6 +583,13 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + llmchain := &apichain.LLMChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, llmchain) chainConfig = &qachain.Spec.CommonChainConfig } else { llmchain := &apichain.LLMChain{ @@ -580,12 +623,22 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat }); err != nil { return nil, err } + qachain := &apichain.RetrievalQAChain{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, qachain) chainConfig = &llmchain.Spec.CommonChainConfig } // create or update retrievers - // knowledgebaseRetriever (must have) -> multiQueryRetriever (optional) -> rerankRetriever (optional) -> Output - hasKnowledgebaseRetriever := utils.HasValue(input.Knowledgebase) + // (must have) (must have) (optional) (optional) + // knowledgebase1 -> knowledgebaseRetriever1 + // knowledgebase2 -> knowledgebaseRetriever2 --> mergerRetriever --> multiQueryRetriever --> rerankRetriever --> Output + // conversation knowledgebase (placeholder or has true data) -> knowledgebaseRetriever3 + hasKnowledgebaseRetriever := hasKnowledgebase hasMultiQueryRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableMultiQuery, false) hasRerankRetriever := hasKnowledgebaseRetriever && pointer.BoolDeref(input.EnableRerank, false) rerankModel := "" @@ -617,6 +670,36 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &knowledgebaseRetriever.Spec.CommonRetrieverConfig + } else { + knowledgebaseRetriever = &apiretriever.KnowledgeBaseRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, knowledgebaseRetriever) + } + + if hasKnowledgebaseRetriever { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + if _, err = controllerutil.CreateOrUpdate(ctx, c, mergerRetriever, func() error { + return nil + }); err != nil { + return nil, err + } + } else { + mergerRetriever := &apiretriever.MergerRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, mergerRetriever) } if hasMultiQueryRetriever { @@ -653,7 +736,16 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat return nil, err } retriever = &multiQueryRetriever.Spec.CommonRetrieverConfig + } else { + multiQueryRetriever := &apiretriever.MultiQueryRetriever{ + ObjectMeta: metav1.ObjectMeta{ + Name: input.Name, + Namespace: input.Namespace, + }, + } + _ = c.Delete(ctx, multiQueryRetriever) } + if hasRerankRetriever { rerankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ @@ -695,17 +787,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat if rerankRetriever.Spec.Model != nil { rerankModel = rerankRetriever.Spec.Model.Name } - } - if !hasMultiQueryRetriever { - multiQueryRetriever := &apiretriever.MultiQueryRetriever{ - ObjectMeta: metav1.ObjectMeta{ - Name: input.Name, - Namespace: input.Namespace, - }, - } - _ = c.Delete(ctx, multiQueryRetriever) - } - if !hasRerankRetriever { + } else { reRankRetriever := &apiretriever.RerankRetriever{ ObjectMeta: metav1.ObjectMeta{ Name: input.Name, @@ -714,6 +796,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } _ = c.Delete(ctx, reRankRetriever) } + // create or update agent for tools var agent *apiagent.Agent if len(input.Tools) != 0 { @@ -765,7 +848,7 @@ func UpdateApplicationConfig(ctx context.Context, c client.Client, input generat } func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfigInput, hasMultiQueryRetriever, hasRerankRetriever bool) error { - app.Spec.Nodes = redefineNodes(input.Knowledgebase, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) + app.Spec.Nodes = redefineNodes(input.Knowledgebases, input.Namespace, input.Name, input.Llm, input.Tools, hasMultiQueryRetriever, hasRerankRetriever, input.EnableUploadFile) app.Spec.Prologue = pointer.StringDeref(input.Prologue, app.Spec.Prologue) app.Spec.ShowRespInfo = pointer.BoolDeref(input.ShowRespInfo, app.Spec.ShowRespInfo) app.Spec.ShowRetrievalInfo = pointer.BoolDeref(input.ShowRetrievalInfo, app.Spec.ShowRetrievalInfo) @@ -779,7 +862,7 @@ func mutateApp(app *v1alpha1.Application, input generated.UpdateApplicationConfi } // redefineNodes redefine nodes in application -func redefineNodes(knowledgebase *string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { +func redefineNodes(knowledgebases []*string, namespace string, name string, llmName string, tools []*generated.ToolInput, hasMultiQueryRetriever, hasRerankRetriever bool, enableUploadFile *bool) (nodes []v1alpha1.Node) { nodes = []v1alpha1.Node{ { NodeConfig: v1alpha1.NodeConfig{ @@ -808,41 +891,25 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName }, NextNodeName: []string{"chain-node"}, }, - } - if pointer.BoolDeref(enableUploadFile, false) { - nodes = append(nodes, v1alpha1.Node{ + { NodeConfig: v1alpha1.NodeConfig{ - Name: "documentloader-node", - DisplayName: "documentloader", - Description: "文档加载,可选", + Name: "llm-node", + DisplayName: "llm", + Description: "设定大模型的访问信息", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "DocumentLoader", - Name: name, + Kind: "LLM", + Name: llmName, Namespace: &namespace, }, }, NextNodeName: []string{"chain-node"}, - }) - } - nodes = append(nodes, v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "llm-node", - DisplayName: "llm", - Description: "设定大模型的访问信息", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "LLM", - Name: llmName, - Namespace: &namespace, - }, }, - NextNodeName: []string{"chain-node"}, - }) + } if len(tools) != 0 { nodes[len(nodes)-1].NextNodeName = []string{"chain-node", "agent-node"} } - if knowledgebase == nil { + if !utils.HasValues(knowledgebases) { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "chain-node", @@ -858,50 +925,66 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } else { - nodes = append(nodes, - v1alpha1.Node{ - NodeConfig: v1alpha1.NodeConfig{ - Name: "knowledgebase-node", - DisplayName: "知识库", - Description: "连接知识库", - Ref: &v1alpha1.TypedObjectReference{ - APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBase", - Name: pointer.StringDeref(knowledgebase, ""), - Namespace: &namespace, + for _, knowledgebase := range knowledgebases { + nodes = append(nodes, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "knowledgebase-node", + DisplayName: "知识库", + Description: "连接知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBase", + Name: pointer.StringDeref(knowledgebase, ""), + Namespace: &namespace, + }, }, + NextNodeName: []string{"retriever-node"}, }, - NextNodeName: []string{"retriever-node"}, - }, + v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "retriever-node", + DisplayName: "从知识库提取信息的retriever", + Description: "连接应用和知识库", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), + Kind: "KnowledgeBaseRetriever", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"mergerretriever-node"}, + }) + } + mergerRetrierNextNodeName := "chain-node" + multiqueryRetrieverNodeName := "chain-node" + switch { + case hasMultiQueryRetriever: + mergerRetrierNextNodeName = "multiqueryretriever-node" + if hasRerankRetriever { + multiqueryRetrieverNodeName = "rerankretriever-node" + } + case !hasMultiQueryRetriever && hasRerankRetriever: + mergerRetrierNextNodeName = "rerankretriever-node" + case !hasMultiQueryRetriever && !hasRerankRetriever: + mergerRetrierNextNodeName = "chain-node" + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ - Name: "retriever-node", - DisplayName: "从知识库提取信息的retriever", - Description: "连接应用和知识库", + Name: "mergerretriever-node", + DisplayName: "知识库合并retriever", + Description: "知识库合并retriever", Ref: &v1alpha1.TypedObjectReference{ APIGroup: pointer.String("retriever.arcadia.kubeagi.k8s.com.cn"), - Kind: "KnowledgeBaseRetriever", + Kind: "MergerRetriever", Name: name, Namespace: &namespace, }, }, - NextNodeName: []string{"chain-node"}, + NextNodeName: []string{mergerRetrierNextNodeName}, }) - knowledgebaseRetrierNextNodeName := "chain-node" - switch { - case hasMultiQueryRetriever: - knowledgebaseRetrierNextNodeName = "multiqueryretriever-node" - case !hasMultiQueryRetriever && hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "rerankretriever-node" - case !hasMultiQueryRetriever && !hasRerankRetriever: - knowledgebaseRetrierNextNodeName = "chain-node" - } - nodes[len(nodes)-1].NextNodeName = []string{knowledgebaseRetrierNextNodeName} if hasMultiQueryRetriever { - nextNodeName := "chain-node" - if hasRerankRetriever { - nextNodeName = "rerankretriever-node" - } nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -915,7 +998,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName Namespace: &namespace, }, }, - NextNodeName: []string{nextNodeName}, + NextNodeName: []string{multiqueryRetrieverNodeName}, }) } if hasRerankRetriever { @@ -951,6 +1034,7 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"Output"}, }) } + if len(tools) != 0 { nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ @@ -967,6 +1051,24 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName NextNodeName: []string{"chain-node"}, }) } + + if pointer.BoolDeref(enableUploadFile, false) { + nodes = append(nodes, v1alpha1.Node{ + NodeConfig: v1alpha1.NodeConfig{ + Name: "documentloader-node", + DisplayName: "documentloader", + Description: "文档加载,可选", + Ref: &v1alpha1.TypedObjectReference{ + APIGroup: pointer.String("arcadia.kubeagi.k8s.com.cn"), + Kind: "DocumentLoader", + Name: name, + Namespace: &namespace, + }, + }, + NextNodeName: []string{"chain-node"}, + }) + } + nodes = append(nodes, v1alpha1.Node{ NodeConfig: v1alpha1.NodeConfig{ Name: "Output", @@ -982,6 +1084,17 @@ func redefineNodes(knowledgebase *string, namespace string, name string, llmName return nodes } +// deprecated +func ConvertKnowledgebase2Knowledgebases(knowledgebase *string, knowledgebases []*string) []*string { + if knowledgebases != nil { + return knowledgebases + } + if knowledgebase == nil { + return nil + } + return []*string{knowledgebase} +} + func UploadIcon(ctx context.Context, client client.Client, icon, appName, namespace string) (string, error) { if strings.HasPrefix(icon, "data:image") { imgBytes, err := pkgutils.ParseBase64ImageBytes(icon) diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go index 5bc2a4329..3f712e294 100644 --- a/apiserver/pkg/chat/chat_server.go +++ b/apiserver/pkg/chat/chat_server.go @@ -29,7 +29,6 @@ import ( langchainllms "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/memory" "github.com/tmc/langchaingo/prompts" - langchainschema "github.com/tmc/langchaingo/schema" "golang.org/x/sync/errgroup" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -303,32 +302,29 @@ func (cs *ChatServer) ListPromptStarters(ctx context.Context, req APPMetadata, l if finish != nil { defer finish() } - v, ok := outArg[base.LangchaingoRetrieverKeyInArg] - if ok { - r, ok := v.(langchainschema.Retriever) - if ok { - doc, err := r.GetRelevantDocuments(ctx, "") - if err != nil { - return nil, err - } - for _, d := range doc { - hasAnswer := false - // has answer, means qa.csv, just return the question - v, ok := d.Metadata[documentloaders.AnswerCol] - if ok { - answer, ok := v.(string) - if ok && answer != "" { - question := strings.TrimSuffix(d.PageContent, "\na: "+answer) - promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) - hasAnswer = true - if len(promptStarters) == limit { - break - } + retrievers, err := base.GetRetrieversFromArg(outArg) + if err == nil && len(retrievers) > 0 { + doc, err := retrievers[0].GetRelevantDocuments(ctx, "") + if err != nil { + return nil, err + } + for _, d := range doc { + hasAnswer := false + // has answer, means qa.csv, just return the question + v, ok := d.Metadata[documentloaders.AnswerCol] + if ok { + answer, ok := v.(string) + if ok && answer != "" { + question := strings.TrimSuffix(d.PageContent, "\na: "+answer) + promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: ")) + hasAnswer = true + if len(promptStarters) == limit { + break } } - if !hasAnswer { - content.WriteString(d.PageContent + "\n") - } + } + if !hasAnswer { + content.WriteString(d.PageContent + "\n") } } } diff --git a/apiserver/pkg/utils/structured.go b/apiserver/pkg/utils/structured.go index 20d1916dc..f72c319e0 100644 --- a/apiserver/pkg/utils/structured.go +++ b/apiserver/pkg/utils/structured.go @@ -38,3 +38,7 @@ func MapStr2Any(input map[string]string) map[string]any { func HasValue(s *string) bool { return s != nil && strings.TrimSpace(*s) != "" } + +func HasValues(s []*string) bool { + return len(s) > 0 && strings.TrimSpace(*s[0]) != "" +} diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/config/samples/app_shared_llm_service_zhipu.yaml b/config/samples/app_shared_llm_service_zhipu.yaml index 9762e60a7..e44aa08ff 100644 --- a/config/samples/app_shared_llm_service_zhipu.yaml +++ b/config/samples/app_shared_llm_service_zhipu.yaml @@ -5,7 +5,7 @@ metadata: namespace: arcadia type: Opaque data: - apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key + apiKey: "YTgyNTlhNjFmN2EwZGYzNmQ5N2Q3ZDIwOGVlMTQ0NTUuODc5OGJyeldwaGUzWUlCOA==" # replace this with your API key --- apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: LLM diff --git a/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml new file mode 100644 index 000000000..221beee1b --- /dev/null +++ b/config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml @@ -0,0 +1,16 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Model +metadata: + name: bce-reranker + namespace: arcadia +spec: + displayName: "bce-reranker-large" + description: | + BCEmbedding是由网易有道开发的中英双语和跨语种语义表征算法模型库,其中包含 EmbeddingModel和 RerankerModel两类基础模型。 + EmbeddingModel专门用于生成语义向量,在语义搜索和问答中起着关键作用,而 RerankerModel擅长优化语义搜索结果和语义相关顺序精排。 + + Github: https://github.com/netease-youdao/BCEmbedding + HuggingFace: https://huggingface.co/maidalun1020/bce-reranker-base_v1 + types: "reranking" + modelScopeRepo: maidalun/bce-reranker-base_v1 + modelSource: modelscope diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index a74ee8d19..4f25507cf 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(Also a KubeBB Component) for KubeAGI Arcadia type: application -version: 0.3.29 +version: 0.3.30 appVersion: "0.2.1" keywords: diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml new file mode 100644 index 000000000..590f3797a --- /dev/null +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_mergerretrievers.yaml @@ -0,0 +1,98 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: mergerretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: MergerRetriever + listKind: MergerRetrieverList + plural: mergerretrievers + singular: mergerretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: MergerRetriever is the Schema for the MergerRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: MergerRetrieverSpec defines the desired state of MergerRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + type: object + status: + description: MergerRetrieverStatus defines the observed state of MergerRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. + type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/go.mod b/go.mod index 8f8600ff6..d4076a31f 100644 --- a/go.mod +++ b/go.mod @@ -215,4 +215,4 @@ require ( sigs.k8s.io/yaml v1.3.0 // indirect ) -replace github.com/tmc/langchaingo => github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 // branch dev +replace github.com/tmc/langchaingo => github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4 // branch dev diff --git a/go.sum b/go.sum index b8718c1b0..d1126d6a5 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/99designs/gqlgen v0.17.40 h1:/l8JcEVQ93wqIfmH9VS1jsAkwm6eAF1NwQn3N+SDqBY= github.com/99designs/gqlgen v0.17.40/go.mod h1:b62q1USk82GYIVjC60h02YguAZLqYZtvWml8KkhJps4= +github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4 h1:32Ne6AEwWJr8Hi40LgI+i5oVc8nB3hbJZ8g3hcEaqb8= +github.com/Abirdcfly/langchaingo v0.0.0-20240402052449-3e0b4b8fedf4/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= @@ -549,8 +551,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91 h1:4VbKHgpTrG/EPWIn4n3FMIedvz3z2aYL2E0Z7RIEvik= -github.com/kubeagi/langchaingo v0.0.0-20240312075057-ca2f549e8d91/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 9045ca778..d4b70284a 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -25,8 +25,6 @@ import ( "strings" langchaingoschema "github.com/tmc/langchaingo/schema" - apierrors "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/strings/slices" "sigs.k8s.io/controller-runtime/pkg/client" @@ -149,32 +147,12 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha base.InputIsNeedStreamKeyInArg: input.NeedStream, base.LangchaingoChatMessageHistoryKeyInArg: input.History, // Use an empty context before run - "context": "", + "context": "", + base.ConversationIDInArg: input.ConversationID, } if a.Spec.DocNullReturn != "" { out[base.APPDocNullReturn] = a.Spec.DocNullReturn } - if input.ConversationID != "" { // means this is not a new conversation - conversationKnowledgebaseExist := true - kb := &arcadiav1alpha1.KnowledgeBase{} - err := cli.Get(ctx, types.NamespacedName{Namespace: a.Namespace, Name: input.ConversationID}, kb) - if err != nil { - if apierrors.IsNotFound(err) { - conversationKnowledgebaseExist = false - // TODO We can search for whether there should be a conversation knowledgebase from the pg - klog.FromContext(ctx).V(5).Info("conversation knowledgebase not exist", "ConversationID", input.ConversationID) - } else { - return output, err - } - } - if conversationKnowledgebaseExist { - if kb.Status.IsReady() { - out[base.ConversationKnowledgeBaseInArg] = kb - } else { - klog.FromContext(ctx).V(3).Info("conversation knowledgebase not ready", "ConversationID", input.ConversationID) - } - } - } visited := make(map[string]bool) waitRunningNodes := list.New() for _, v := range a.StartingNodes { @@ -194,7 +172,6 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha waitRunningNodes.PushBack(e) continue } - klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name())) defer func() { if r := recover(); r != nil { klog.FromContext(ctx).Info(fmt.Sprintf("Recovered from node:%s error:%s stack:%s", e.Name(), r, string(debug.Stack()))) @@ -280,6 +257,9 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha case "multiqueryretriever": logger.V(3).Info("initnode multiqueryretriever") return retriever.NewMultiQueryRetriever(baseNode), nil + case "mergeretriever": + logger.V(3).Info("initnode mergeretriever") + return retriever.NewMergerRetriever(baseNode), nil default: return nil, err } diff --git a/pkg/appruntime/base/keyword.go b/pkg/appruntime/base/keyword.go index 465e19443..eece1ee61 100644 --- a/pkg/appruntime/base/keyword.go +++ b/pkg/appruntime/base/keyword.go @@ -16,6 +16,12 @@ limitations under the License. package base +import ( + "errors" + + langchainschema "github.com/tmc/langchaingo/schema" +) + const ( InputQuestionKeyInArg = "question" InputIsNeedStreamKeyInArg = "_need_stream" @@ -25,9 +31,58 @@ const ( MapReduceDocumentOutputInArg = "_mapreduce_document_answer" OutputAnserStreamChanKeyInArg = "_answer_stream" RuntimeRetrieverReferencesKeyInArg = "_references" - LangchaingoRetrieverKeyInArg = "retriever" + LangchaingoRetrieversKeyInArg = "retrievers" LangchaingoLLMKeyInArg = "llm" LangchaingoPromptKeyInArg = "prompt" APPDocNullReturn = "_app_doc_null_return" ConversationKnowledgeBaseInArg = "_conversation_knowledgebase" // the conversation Knowledgebase cr in args, status has ready + ConversationIDInArg = "_conversation_id" ) + +func GetInputQuestionFromArg(args map[string]any) (string, error) { + q, ok := args[InputQuestionKeyInArg] + if !ok { + return "", errors.New("no question in args") + } + query, ok := q.(string) + if !ok || len(query) == 0 { + return "", errors.New("empty question") + } + return query, nil +} + +func GetRetrieversFromArg(args map[string]any) ([]langchainschema.Retriever, error) { + v, ok := args[LangchaingoRetrieversKeyInArg] + if !ok { + return nil, errors.New("no retrievers in args") + } + retrievers, ok := v.([]langchainschema.Retriever) + if !ok { + return nil, errors.New("retrievers not []schema.Retriever") + } + return retrievers, nil +} + +func GetAPPDocNullReturnFromArg(args map[string]any) (string, error) { + v, ok := args[APPDocNullReturn] + if !ok { + return "", nil + } + docNullReturn, ok := v.(string) + if !ok { + return "", errors.New("app doc null return not string type") + } + return docNullReturn, nil +} + +// AddKnowledgebaseRetrieverToArg add knowledgebase retriever to args +// Note: only knowledgebase retriever will be appended, other components like qachain will only use the first retriever in args +func AddKnowledgebaseRetrieverToArg(args map[string]any, retriever langchainschema.Retriever) map[string]any { + if _, exist := args[LangchaingoRetrieversKeyInArg]; !exist { + args[LangchaingoRetrieversKeyInArg] = make([]langchainschema.Retriever, 0) + } + retrievers := args[LangchaingoRetrieversKeyInArg].([]langchainschema.Retriever) + retrievers = append(retrievers, retriever) + args[LangchaingoRetrieversKeyInArg] = retrievers + return args +} diff --git a/pkg/appruntime/base/node.go b/pkg/appruntime/base/node.go index 7122c608b..c401fc639 100644 --- a/pkg/appruntime/base/node.go +++ b/pkg/appruntime/base/node.go @@ -111,8 +111,8 @@ func (c *BaseNode) Init(_ context.Context, _ client.Client, _ map[string]any) er return nil } -func (c *BaseNode) Run(_ context.Context, _ client.Client, _ map[string]any) (map[string]any, error) { - return nil, nil +func (c *BaseNode) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + return args, nil } func (c *BaseNode) Ready() (bool, string) { diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index dacd99a56..bf100d537 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -27,7 +27,6 @@ import ( langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" @@ -57,19 +56,6 @@ func (l *LLMChain) Init(ctx context.Context, cli client.Client, args map[string] } func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]any) (outArgs map[string]any, err error) { - if args[base.ConversationKnowledgeBaseInArg] != nil { - chain := NewRetrievalQAChain(l.BaseNode) - chain.Instance = &v1alpha1.RetrievalQAChain{ - Spec: v1alpha1.RetrievalQAChainSpec{ - CommonChainConfig: v1alpha1.CommonChainConfig{ - Memory: v1alpha1.Memory{ - ConversionWindowSize: pointer.Int(v1alpha1.DefaultConversionWindowSize), - }, - }, - }, - } - return chain.Run(ctx, cli, args) - } v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") @@ -122,7 +108,7 @@ func (l *LLMChain) Run(ctx context.Context, cli client.Client, args map[string]a // Add the agent output to the context if it's not empty if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get answer from upstream: %s", args[base.AgentOutputInArg])) - args["context"] = fmt.Sprintf("%s\n%s", args["context"], args[base.AgentOutputInArg]) + args["context"] = fmt.Sprintf("%s\n %s", args["context"], fmt.Sprintf("Tool Output of %s is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))) } // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index 366586c94..3b332bd4d 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -27,12 +27,9 @@ import ( langchainschema "github.com/tmc/langchaingo/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" - "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" - arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/log" appruntimeretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" @@ -77,27 +74,12 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ if !ok { return args, errors.New("prompt not prompts.FormatPrompter") } - vkb, ok := args[base.ConversationKnowledgeBaseInArg] - if ok { - kb, ok := vkb.(*arcadiav1alpha1.KnowledgeBase) - if !ok { - return args, errors.New("knowledgebase not arcadiav1alpha1.KnowledgeBase") - } - args, finish, err := appruntimeretriever.GenerateKnowledgebaseRetriever(ctx, cli, kb.Name, kb.Namespace, apiretriever.CommonRetrieverConfig{ScoreThreshold: pointer.Float32(apiretriever.DefaultScoreThreshold), NumDocuments: 20}, args) - if err != nil { - return args, err - } - defer finish() - } - v3, ok := args["retriever"] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v3.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err } + retriever := retrieversInArg[0] v4, ok := args[base.LangchaingoChatMessageHistoryKeyInArg] if !ok { @@ -130,6 +112,10 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ } } + docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) + if err != nil { + return args, fmt.Errorf("can't get doc from retriever: %w", err) + } /* Note: we can only add context to retriever's documents not args["context"] if we use ConversationalRetrievalQA chain see https://github.com/kubeagi/langchaingo/blob/ca2f549e8d91788fd76a9f8706afcaae617275c5/chains/stuff_documents.go#L54-L69 @@ -139,14 +125,17 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ // Add the agent output to qachain context if args[base.AgentOutputInArg] != nil { klog.FromContext(ctx).V(5).Info(fmt.Sprintf("get context from agent: %s", args[base.AgentOutputInArg])) - docs, err := retriever.GetRelevantDocuments(ctx, args["question"].(string)) - if err != nil { - return args, fmt.Errorf("can't get doc from retriever: %w", err) - } doc := langchainschema.Document{PageContent: fmt.Sprintf("Tool Output of %s is: %s", args["question"].(string), args[base.AgentOutputInArg].(string))} docs = append(docs, doc) retriever = &appruntimeretriever.Fakeretriever{Docs: docs, Name: "AddAgentOutputRetriever"} } + if len(docs) == 0 { + docNullReturn, err := base.GetAPPDocNullReturnFromArg(args) + if err == nil && len(docNullReturn) > 0 { + return args, &base.RetrieverGetNullDocError{Msg: docNullReturn} + } + } + // Add the mapReduceDocument output to the context if it's not empty if args[base.MapReduceDocumentOutputInArg] != nil { // Note: Now args[base.MapReduceDocumentOutputInArg] will not be null only for the first "summarize content of uploaded document" Chat request @@ -167,28 +156,42 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli client.Client, args map[ condenseQustionGenerator.CallbacksHandler = log.KLogHandler{LogLevel: 3} chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), condenseQustionGenerator, retriever, GetMemory(llm, instance.Spec.Memory, history, "", "")) chain.RephraseQuestion = false + chain.ReturnSourceDocuments = true l.ConversationalRetrievalQA = chain args["query"] = args["question"] - var out string + var ( + out string + outputValues map[string]any + ) needStream := false needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, chains.WithStreamingFunc(stream(args))) - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { if len(options) > 0 { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args, options...) } else { - out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args) + outputValues, err = chains.Call(ctx, l.ConversationalRetrievalQA, args) } } + // _llmChainDefaultOutputKey + out, _ = outputValues["text"].(string) out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options) klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out) if err == nil { args[base.OutputAnserKeyInArg] = out + // _conversationalRetrievalQADefaultSourceDocumentKey + doc, ok := outputValues["source_documents"].([]langchainschema.Document) + if ok { + _, refs := appruntimeretriever.ConvertDocuments(ctx, doc, "retrievalqachain") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + } return args, nil } + return args, fmt.Errorf("retrievalqachain run error: %w", err) } diff --git a/pkg/appruntime/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go index e57ee21d1..8448f7a36 100644 --- a/pkg/appruntime/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -23,6 +23,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" ) @@ -39,6 +40,9 @@ func NewKnowledgebase(baseNode base.BaseNode) *Knowledgebase { } func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return nil + } instance := &v1alpha1.KnowledgeBase{} if err := cli.Get(ctx, types.NamespacedName{Namespace: k.RefNamespace(), Name: k.Ref.Name}, instance); err != nil { return fmt.Errorf("can't find the knowledgebase in cluster: %w", err) @@ -47,10 +51,9 @@ func (k *Knowledgebase) Init(ctx context.Context, cli client.Client, _ map[strin return nil } -func (k *Knowledgebase) Run(_ context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - return args, nil -} - func (k *Knowledgebase) Ready() (isReady bool, msg string) { + if appnode.IsPlaceholderConversationKnowledgebase(k.Ref.Name) { + return true, "" + } return k.Instance.Status.IsReadyOrGetReadyMessage() } diff --git a/pkg/appruntime/retriever/common.go b/pkg/appruntime/retriever/common.go index 50bf2e2f9..b45245d4f 100644 --- a/pkg/appruntime/retriever/common.go +++ b/pkg/appruntime/retriever/common.go @@ -79,6 +79,7 @@ func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { return args } +// ConvertDocuments, convert raw doc to what we want doc, for example, knowledgebase doc should add answer into page content func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, retrieverName string) (newDocs []langchaingoschema.Document, refs []Reference) { logger := klog.FromContext(ctx) docLen := len(docs) @@ -108,7 +109,7 @@ func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, re doc.PageContent = doc.PageContent + joinStr + answer } } - if retrieverName == "multiquery" { + if retrieverName == "multiquery" || retrieverName == "retrievalqachain" { // pageContent may have the answer in previous steps, and we want this field only has question in reference output pageContent = strings.TrimSuffix(pageContent, joinStr+answer) } diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index 28450a6fd..e1c207da1 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,16 +18,16 @@ package retriever import ( "context" - "errors" "fmt" - langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" "k8s.io/utils/pointer" "sigs.k8s.io/controller-runtime/pkg/client" + appnode "github.com/kubeagi/arcadia/api/app-node" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" @@ -101,7 +101,21 @@ func (l *KnowledgeBaseRetriever) Cleanup() { func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, knowledgebaseName, knowledgebaseNamespace string, retrieverConfig apiretriever.CommonRetrieverConfig, args map[string]any) (outArg map[string]any, finish func(), err error) { knowledgebase := &v1alpha1.KnowledgeBase{} + isConversationKnowledgebase := appnode.IsPlaceholderConversationKnowledgebase(knowledgebaseName) + if isConversationKnowledgebase { + v, ok := args[base.ConversationIDInArg] + if ok { + conversationID, ok := v.(string) + if ok && conversationID != "" { + knowledgebaseName = conversationID + } + } + } if err := cli.Get(ctx, types.NamespacedName{Namespace: knowledgebaseNamespace, Name: knowledgebaseName}, knowledgebase); err != nil { + if isConversationKnowledgebase && apierrors.IsNotFound(err) { // When there is a conversationID, look for the corresponding conversation knowledgebase. This knowledgebase may not exist. This is not a error + // TODO We can search for whether there should be a conversation knowledgebase from the pg + return args, nil, nil + } return nil, nil, fmt.Errorf("can't find the knowledgebase in cluster: %w", err) } @@ -138,40 +152,14 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } retriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} - question, ok := args["question"] - if !ok { - return nil, finish, errors.New("no question in args") - } - query, ok := question.(string) - if !ok { - return nil, finish, errors.New("question not string") + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return nil, finish, err } docs, err := retriever.GetRelevantDocuments(ctx, query) if err != nil { return nil, finish, fmt.Errorf("can't get relevant documents: %w", err) } - oldDocs := make([]langchaingoschema.Document, 0) - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if ok { - // may exist other retriever, like conversation retriever - oldRetriever, ok := v.(langchaingoschema.Retriever) - if ok { - oldDocs, err = oldRetriever.GetRelevantDocuments(ctx, query) - if err != nil { - return nil, finish, fmt.Errorf("can't get old doc: %w", err) - } - } - } - if len(docs) == 0 && len(oldDocs) == 0 { - // FIXME: 需要决定当知识库找不到相关内容,但是conversation知识库存在文档时如何处理 - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, finish, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - } // pgvector get score means vector distance, similarity = 1 - vector distance // chroma get score means similarity // we want similarity finally. @@ -181,7 +169,7 @@ func GenerateKnowledgebaseRetriever(ctx context.Context, cli client.Client, know } } docs, refs := ConvertDocuments(ctx, docs, "knowledgebase") - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: append(docs, oldDocs...), Name: "KnowledgebaseRetriever"} - AddReferencesToArgs(args, refs) + args = AddReferencesToArgs(args, refs) + args = base.AddKnowledgebaseRetrieverToArg(args, &Fakeretriever{Docs: docs, Name: "KnowledgebaseRetriever"}) return args, finish, nil } diff --git a/pkg/appruntime/retriever/mergerretriever.go b/pkg/appruntime/retriever/mergerretriever.go new file mode 100644 index 000000000..698462871 --- /dev/null +++ b/pkg/appruntime/retriever/mergerretriever.go @@ -0,0 +1,75 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retriever + +import ( + "context" + "fmt" + + langchainretrievers "github.com/tmc/langchaingo/retrievers" + "github.com/tmc/langchaingo/schema" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/base" +) + +type MergerRetriever struct { + base.BaseNode + Instance *apiretriever.MergerRetriever +} + +func NewMergerRetriever(baseNode base.BaseNode) *MergerRetriever { + return &MergerRetriever{ + BaseNode: baseNode, + } +} + +func (l *MergerRetriever) Init(ctx context.Context, cli client.Client, _ map[string]any) error { + instance := &apiretriever.MergerRetriever{} + if err := cli.Get(ctx, types.NamespacedName{Namespace: l.RefNamespace(), Name: l.BaseNode.Ref.Name}, instance); err != nil { + return fmt.Errorf("can't find the merger retriever in cluster: %w", err) + } + l.Instance = instance + return nil +} + +func (l *MergerRetriever) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err + } + r := langchainretrievers.NewMergerRetriever(retrievers) + query, err := base.GetInputQuestionFromArg(args) + if err != nil { + return args, err + } + docs, err := r.GetRelevantDocuments(ctx, query) + if err != nil { + return args, err + } + docs, refs := ConvertDocuments(ctx, docs, "mergerretriever") + // note: the references in args will be replaced, not append + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + args[base.LangchaingoRetrieversKeyInArg] = []schema.Retriever{&Fakeretriever{Docs: docs, Name: "mergerRetriever"}} + return args, nil +} + +func (l *MergerRetriever) Ready() (isReady bool, msg string) { + return l.Instance.Status.IsReadyOrGetReadyMessage() +} diff --git a/pkg/appruntime/retriever/multiqueryretriever.go b/pkg/appruntime/retriever/multiqueryretriever.go index 610ec9c1f..36c05b232 100644 --- a/pkg/appruntime/retriever/multiqueryretriever.go +++ b/pkg/appruntime/retriever/multiqueryretriever.go @@ -69,13 +69,9 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m return args, errors.New("empty question") } - v1, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v1.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrieversInArg, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err } v2, ok := args[base.LangchaingoLLMKeyInArg] @@ -88,7 +84,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m } prompt := prompts.NewPromptTemplate(_defaultQueryTemplate, []string{"question"}) llmchain := chains.NewLLMChain(llm, prompt, chains.WithCallback(log.KLogHandler{LogLevel: 3})) - multiqueryRetriever := retrievers.NewMultiQueryRetriever(retriever, llmchain, true) + multiqueryRetriever := retrievers.NewMultiQueryRetriever(retrieversInArg[0], llmchain, true) multiqueryRetriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} docs, err := multiqueryRetriever.GetRelevantDocuments(ctx, query) if err != nil { @@ -107,7 +103,7 @@ func (l *MultiQueryRetriever) Run(ctx context.Context, cli client.Client, args m newDocs, newRef := ConvertDocuments(ctx, newDocs, "multiquery") // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "MultiqueryRetriever"}} return args, nil } diff --git a/pkg/appruntime/retriever/rerankretriever.go b/pkg/appruntime/retriever/rerankretriever.go index 4ae0c1fe6..f4ff88076 100644 --- a/pkg/appruntime/retriever/rerankretriever.go +++ b/pkg/appruntime/retriever/rerankretriever.go @@ -66,14 +66,8 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s return args, errors.New("empty references") } if len(references) == 0 { - v, exist := args[base.APPDocNullReturn] - if exist { - docNullReturn, ok := v.(string) - if ok && len(docNullReturn) > 0 { - return nil, &base.RetrieverGetNullDocError{Msg: docNullReturn} - } - } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: nil, Name: "RerankRetriever"} + klog.FromContext(ctx).V(3).Info("rerank retriever get no references, skip rerank") + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: nil, Name: "RerankRetriever"}} return args, nil } q, ok := args[base.InputQuestionKeyInArg] @@ -143,15 +137,11 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s // note: the references in args will be replaced, not append args[base.RuntimeRetrieverReferencesKeyInArg] = newRef - v, ok := args[base.LangchaingoRetrieverKeyInArg] - if !ok { - return args, errors.New("no retriever") - } - retriever, ok := v.(langchainschema.Retriever) - if !ok { - return args, errors.New("retriever not schema.Retriever") + retrievers, err := base.GetRetrieversFromArg(args) + if err != nil { + return args, err } - docs, err := retriever.GetRelevantDocuments(ctx, query) + docs, err := retrievers[0].GetRelevantDocuments(ctx, query) if err != nil { return args, fmt.Errorf("get relevant documents failed: %w", err) } @@ -163,7 +153,7 @@ func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[s } } } - args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs, Name: "RerankRetriever"} + args[base.LangchaingoRetrieversKeyInArg] = []langchainschema.Retriever{&Fakeretriever{Docs: newDocs, Name: "RerankRetriever"}} return args, nil } diff --git a/tests/example-test.sh b/tests/example-test.sh index 6e80dd971..fcd0e15ed 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -40,314 +40,11 @@ Timeout="${TimeoutSeconds}s" mkdir ${TempFilePath} || true env -function debugInfo { - if [[ $? -eq 0 ]]; then - exit 0 - fi - if [[ $debug -ne 0 ]]; then - exit 1 - fi - if [[ $GITHUB_ACTIONS == "true" ]]; then - warning "debugInfo start 🧐" - mkdir -p $LOG_DIR || true - df -h - - warning "1. Try to get all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-list.log - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A -oyaml --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-yaml.log - - warning "2. Try to describe all resources " - kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl describe -A >$LOG_DIR/describe-all-resources.log - - warning "3. Try to export kind logs to $LOG_DIR..." - kind export logs --name=${KindName} $LOG_DIR - sudo chown -R $USER:$USER $LOG_DIR - - warning "debugInfo finished ! " - warning "This means that some tests have failed. Please check the log. 🌚" - debug=1 - fi - exit 1 -} +source ./tests/scripts/utils.sh trap 'debugInfo $LINENO' ERR trap 'debugInfo $LINENO' EXIT debug=0 -function cecho() { - declare -A colors - colors=( - ['black']='\E[0;47m' - ['red']='\E[0;31m' - ['green']='\E[0;32m' - ['yellow']='\E[0;33m' - ['blue']='\E[0;34m' - ['magenta']='\E[0;35m' - ['cyan']='\E[0;36m' - ['white']='\E[0;37m' - ) - local defaultMSG="No message passed." - local defaultColor="black" - local defaultNewLine=true - while [[ $# -gt 1 ]]; do - key="$1" - case $key in - -c | --color) - color="$2" - shift - ;; - -n | --noline) - newLine=false - ;; - *) - # unknown option - ;; - esac - shift - done - message=${1:-$defaultMSG} # Defaults to default message. - color=${color:-$defaultColor} # Defaults to default color, if not specified. - newLine=${newLine:-$defaultNewLine} - echo -en "${colors[$color]}" - echo -en "$message" - if [ "$newLine" = true ]; then - echo - fi - tput sgr0 # Reset text attributes to normal without clearing screen. - return -} - -function warning() { - cecho -c 'yellow' "$@" -} - -function error() { - cecho -c 'red' "$@" -} - -function info() { - cecho -c 'blue' "$@" -} - -function waitPodReady() { - namespace=$1 - podLabel=$2 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get po -l ${podLabel} --ignore-not-found=true -o json | jq -r '.items[0].status.conditions[] | select(."type"=="Ready") | .status') - if [[ $readStatus == "True" ]]; then - info "Pod:${podLabel} ready" - break - fi - kubectl -n${namespace} get po -l ${podLabel} - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - kubectl describe po -n${namespace} -l ${podLabel} - kubectl get po -n${namespace} --show-labels - exit 1 - fi - sleep 5 - done -} - -function EnableAPIServerPortForward() { - waitPodReady "arcadia" "app=arcadia-apiserver" - if [ $portal_pid -ne 0 ]; then - kill $portal_pid >/dev/null 2>&1 - fi - echo "re port-forward apiserver..." - kubectl port-forward svc/arcadia-apiserver -n arcadia 8081:8081 >/dev/null 2>&1 & - portal_pid=$! - sleep 3 - info "port-forward apiserver in pid: $portal_pid" -} - -function waitCRDStatusReady() { - source=$1 - namespace=$2 - name=$3 - START_TIME=$(date +%s) - while true; do - readStatus=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].status') - message=$(kubectl -n${namespace} get ${source} ${name} --ignore-not-found=true -o json | jq -r '.status.conditions[0].message') - if [[ $readStatus == "True" ]]; then - info $message - if [[ ${source} == "KnowledgeBase" ]]; then - fileStatus=$(kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails[0].phase') - if [[ $fileStatus != "Succeeded" ]]; then - kubectl get knowledgebase -n $namespace $name -o json | jq -r '.status.fileGroupDetail[0].fileDetails' - exit 1 - fi - fi - break - fi - - CURRENT_TIME=$(date +%s) - ELAPSED_TIME=$((CURRENT_TIME - START_TIME)) - if [[ ${source} == "Worker" ]]; then - if [ $ELAPSED_TIME -gt 1800 ]; then - error "Timeout reached" - exit 1 - fi - else - if [ $ELAPSED_TIME -gt $TimeoutSeconds ]; then - error "Timeout reached" - exit 1 - fi - fi - sleep 5 - done -} - -function getRespInAppChat() { - appname=$1 - namespace=$2 - query=$3 - conversationID=$4 - testStream=$5 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${query}" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg appname "$appname" --arg query "$query" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"streaming","conversation_id":$conversationID,"app_name":$appname}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - -function fileUploadSummarise() { - appname=$1 - namespace=$2 - filename=$3 - attempt=0 - while true; do - info "sleep 3 seconds" - sleep 3 - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST --form file=@$filename --form app_name=$appname -H "namespace: ${namespace}" -H "Content-Type: multipart/form-data" http://127.0.0.1:8081/chat/conversations/file) - doc_data=$(echo $resp | jq -r '.document') - if [ -z "$doc_data" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: ${filename}" - echo "🤖: ${doc_data}" - break - done - file_id=$(echo $resp | jq -r '.document.object') - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - attempt=0 - while true; do - info "sleep 3 seconds to sumerize doc" - sleep 3 - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - resp=$(curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}") - ai_data=$(echo $resp | jq -r '.message') - references=$(echo $resp | jq -r '.references') - if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then - echo $resp - EnableAPIServerPortForward - if [[ $resp == *"googleapi: Error"* ]]; then - echo "google api error, will retry after 60s" - sleep 60 - fi - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - continue - fi - echo "👤: 总结一下" - echo "🤖: ${ai_data}" - echo "🔗: ${references}" - break - done - resp_conversation_id=$(echo $resp | jq -r '.conversation_id') - - if [ $testStream == "true" ]; then - attempt=0 - while true; do - info "sleep 5 seconds" - sleep 5 - info "just test stream mode" - data=$(jq -n --arg fileid "$file_id" --arg appname "$appname" --arg query "总结一下" --arg conversationID "$resp_conversation_id" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "files": [$fileid]}') - curl --max-time $TimeoutSeconds -s --show-error -XPOST http://127.0.0.1:8081/chat --data "$data" -H "namespace: ${namespace}" - if [[ $? -ne 0 ]]; then - attempt=$((attempt + 1)) - if [ $attempt -gt $RETRY_COUNT ]; then - echo "❌: Failed. Retry count exceeded." - exit 1 - fi - echo "🔄: Failed. Attempt $attempt/$RETRY_COUNT" - EnableAPIServerPortForward - echo "and wait 60s for google api error" - sleep 60 - continue - fi - break - done - fi -} - info "1. create kind cluster" make kind df -h @@ -590,7 +287,14 @@ kubectl patch applications -n arcadia base-chat-with-knowledgebase-pgvector -p ' getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" "" "true" info "8.2.3 QA app using knowledgebase base on pgvector and rerank" -kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml +if [[ $GITHUB_ACTIONS == "true" ]]; then + info "in github action, download model from huggingface" + #FIXME #kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce.yaml + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml +else + info "in local, download model from modelscope" + kubectl apply -f config/samples/arcadia_v1alpha1_model_reranking_bce_modelscope.yaml +fi waitCRDStatusReady "Model" "arcadia" "bce-reranker" kubectl apply -f config/samples/arcadia_v1alpha1_worker_reranking_bce.yaml waitCRDStatusReady "Worker" "arcadia" "bce-reranker"