diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index faa381c85..e3af762f9 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -211,6 +211,11 @@ func GetApplication(ctx context.Context, c dynamic.Interface, name, namespace st if err != nil { return nil, err } + app := &v1alpha1.Application{} + if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Application"), namespace, name, app); err != nil { + return nil, err + } + prompt := &apiprompt.Prompt{} if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Prompt"), namespace, name, prompt); err != nil { return nil, err @@ -218,30 +223,37 @@ func GetApplication(ctx context.Context, c dynamic.Interface, name, namespace st var ( chainConfig *apichain.CommonChainConfig llmChainInput *apichain.LLMChainInput + retriever *apiretriever.KnowledgeBaseRetriever ) - qachain := &apichain.RetrievalQAChain{} - if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "RetrievalQAChain"), namespace, name, qachain); err != nil { - return nil, err - } - if qachain.UID != "" { - chainConfig = &qachain.Spec.CommonChainConfig - llmChainInput = &qachain.Spec.Input.LLMChainInput - } - llmchain := &apichain.LLMChain{} - if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "LLMChain"), namespace, name, llmchain); err != nil { - return nil, err - } - if llmchain.UID != "" { - chainConfig = &llmchain.Spec.CommonChainConfig - llmChainInput = &llmchain.Spec.Input - } - retriever := &apiretriever.KnowledgeBaseRetriever{} - if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "KnowledgeBaseRetriever"), namespace, name, retriever); err != nil { - return nil, err + hasKnowledgeBaseRetriever := false + for _, node := range app.Spec.Nodes { + if node.Ref != nil && node.Ref.APIGroup != nil && *node.Ref.APIGroup == apiretriever.Group { + hasKnowledgeBaseRetriever = true + 
break + } } - app := &v1alpha1.Application{} - if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Application"), namespace, name, app); err != nil { - return nil, err + if hasKnowledgeBaseRetriever { + qachain := &apichain.RetrievalQAChain{} + if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "RetrievalQAChain"), namespace, name, qachain); err != nil { + return nil, err + } + if qachain.UID != "" { + chainConfig = &qachain.Spec.CommonChainConfig + llmChainInput = &qachain.Spec.Input.LLMChainInput + } + retriever = &apiretriever.KnowledgeBaseRetriever{} + if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "KnowledgeBaseRetriever"), namespace, name, retriever); err != nil { + return nil, err + } + } else { + llmchain := &apichain.LLMChain{} + if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "LLMChain"), namespace, name, llmchain); err != nil { + return nil, err + } + if llmchain.UID != "" { + chainConfig = &llmchain.Spec.CommonChainConfig + llmChainInput = &llmchain.Spec.Input + } } return cr2app(prompt, chainConfig, llmChainInput, retriever, app) @@ -264,8 +276,6 @@ func ListApplicationMeatadatas(ctx context.Context, c dynamic.Interface, input g return res.Items[i].GetCreationTimestamp().After(res.Items[j].GetCreationTimestamp().Time) }) - totalCount := len(res.Items) - filterd := make([]generated.PageNode, 0) for _, u := range res.Items { if keyword != "" { @@ -280,6 +290,8 @@ func ListApplicationMeatadatas(ctx context.Context, c dynamic.Interface, input g } filterd = append(filterd, m) } + totalCount := len(filterd) + end := page * pageSize if end > totalCount { end = totalCount @@ -517,6 +529,7 @@ func UpdateApplicationConfig(ctx context.Context, c dynamic.Interface, input gen retriever.Spec.ScoreThreshold = float32(pointer.Float64Deref(input.ScoreThreshold, float64(retriever.Spec.ScoreThreshold))) retriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, 
retriever.Spec.NumDocuments) retriever.Spec.DocNullReturn = pointer.StringDeref(input.DocNullReturn, retriever.Spec.DocNullReturn) + retriever.Spec.Input.KnowledgeBaseRef.Name = pointer.StringDeref(input.Knowledgebase, retriever.Spec.Input.KnowledgeBaseRef.Name) }, retriever); err != nil { return nil, err } diff --git a/apiserver/pkg/auth/auth.go b/apiserver/pkg/auth/auth.go index d897f6d4f..9753bd056 100644 --- a/apiserver/pkg/auth/auth.go +++ b/apiserver/pkg/auth/auth.go @@ -36,7 +36,12 @@ import ( client1 "github.com/kubeagi/arcadia/apiserver/pkg/client" ) -type idtokenKey struct{} +type contextKey string + +const ( + idTokenContextKey contextKey = "idToken" + UserNameContextKey contextKey = "userName" +) type User struct { Name string `json:"name"` @@ -61,11 +66,11 @@ func isBearerToken(token string) (bool, string) { return head == "bearer" && len(payload) > 0, payload } -func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespace string) (bool, error) { +func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespace string) (bool, string, error) { u := &User{} if err := oidcToken.Claims(u); err != nil { klog.Errorf("parse user info from idToken, error %v", err) - return false, fmt.Errorf("can't parse user info") + return false, "", fmt.Errorf("can't parse user info") } av := av1.SubjectAccessReview{ @@ -87,15 +92,15 @@ func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespac if err != nil { err = fmt.Errorf("auth can-i failed, error %w", err) klog.Error(err) - return false, err + return false, "", err } ok, found, err := unstructured.NestedBool(u1.Object, "status", "allowed") if err != nil || !found { klog.Warning("not found allowed filed or some errors occurred.") - return false, err + return false, "", err } - return ok, nil + return ok, u.Name, nil } func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, verb, resources string) gin.HandlerFunc { @@ -133,7 +138,7 @@ func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, 
verb, re return } if verb != "" { - allowed, err := cani(client, oidcIDtoken, resources, verb, namespace) + allowed, userName, err := cani(client, oidcIDtoken, resources, verb, namespace) if err != nil { klog.Errorf("auth error: failed to checkout permission. error %s", err) ctx.AbortWithStatusJSON(http.StatusForbidden, gin.H{ @@ -148,17 +153,17 @@ func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, verb, re }) return } + ctx.Request = ctx.Request.WithContext(context.WithValue(ctx.Request.Context(), UserNameContextKey, userName)) } // for graphql query - ctx1 := context.WithValue(ctx.Request.Context(), idtokenKey{}, rawToken) - ctx.Request = ctx.Request.WithContext(ctx1) + ctx.Request = ctx.Request.WithContext(context.WithValue(ctx.Request.Context(), idTokenContextKey, rawToken)) ctx.Next() } } func ForOIDCToken(ctx context.Context) *string { - v, _ := ctx.Value(idtokenKey{}).(string) + v, _ := ctx.Value(idTokenContextKey).(string) if v == "" { return nil } diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat.go index 50668e776..a43a4d780 100644 --- a/apiserver/pkg/chat/chat.go +++ b/apiserver/pkg/chat/chat.go @@ -19,6 +19,7 @@ package chat import ( "context" "errors" + "sync" "time" "github.com/tmc/langchaingo/memory" @@ -33,9 +34,13 @@ import ( "github.com/kubeagi/arcadia/apiserver/pkg/client" "github.com/kubeagi/arcadia/pkg/application" "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/application/retriever" ) -var Conversions = map[string]Conversion{} +var ( + mu sync.Mutex + Conversions = map[string]Conversion{} +) func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*ChatRespBody, error) { token := auth.ForOIDCToken(ctx) @@ -57,15 +62,24 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat return nil, errors.New("application is not ready") } var conversion Conversion + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) if 
req.ConversionID != "" { var ok bool + mu.Lock() /* Conversions is written under mu by concurrent requests, so the lookup must hold it too */ conversion, ok = Conversions[req.ConversionID] + mu.Unlock() if !ok { return nil, errors.New("conversion is not found") } + if currentUser != "" && currentUser != conversion.User { + return nil, errors.New("conversion id not match with user") + } if conversion.AppName != req.APPName || conversion.AppNamespce != req.AppNamespace { return nil, errors.New("conversion id not match with app info") } + if conversion.Debug != req.Debug { + return nil, errors.New("conversion id not match with debug") + } } else { conversion = Conversion{ ID: string(uuid.NewUUID()), @@ -75,10 +89,13 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat UpdatedAt: time.Now(), Messages: make([]Message, 0), History: memory.NewChatMessageHistory(), + User: currentUser, + Debug: req.Debug, } } + messageID := string(uuid.NewUUID()) conversion.Messages = append(conversion.Messages, Message{ - ID: string(uuid.NewUUID()), + ID: messageID, Query: req.Query, Answer: "", }) @@ -95,12 +112,71 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat conversion.UpdatedAt = time.Now() conversion.Messages[len(conversion.Messages)-1].Answer = out.Answer + conversion.Messages[len(conversion.Messages)-1].References = out.References + mu.Lock() Conversions[conversion.ID] = conversion + mu.Unlock() return &ChatRespBody{ ConversionID: conversion.ID, + MessageID: messageID, Message: out.Answer, CreatedAt: time.Now(), + References: out.References, }, nil } +func ListConversations(ctx context.Context, req APPMetadata) ([]Conversion, error) { + conversations := make([]Conversion, 0) + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + mu.Lock() + for _, c := range Conversions { + if !c.Debug && c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && (currentUser == "" || currentUser == c.User) { + conversations = append(conversations, c) + } + } + mu.Unlock() + return conversations, nil +} + +func 
DeleteConversation(ctx context.Context, conversionID string) error { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + mu.Lock() + defer mu.Unlock() + c, ok := Conversions[conversionID] + if ok && (currentUser == "" || currentUser == c.User) { + delete(Conversions, c.ID) + return nil + } else { + return errors.New("conversion is not found") + } +} + +func ListMessages(ctx context.Context, req ConversionReqBody) (Conversion, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + mu.Lock() + defer mu.Unlock() + for _, c := range Conversions { + if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && req.ConversionID == c.ID && (currentUser == "" || currentUser == c.User) { + return c, nil + } + } + return Conversion{}, errors.New("conversion is not found") +} + +func GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + mu.Lock() + defer mu.Unlock() + for _, c := range Conversions { + if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && c.ID == req.ConversionID && (currentUser == "" || currentUser == c.User) { + for _, m := range c.Messages { + if m.ID == req.MessageID { + return m.References, nil + } + } + } + } + return nil, errors.New("conversion or message is not found") +} + // todo Reuse the flow without having to rebuild req same, not finish, Flow doesn't start with/contain nodes that depend on incomingInput.question diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/chat_type.go index 5d376bae1..2a45bfd34 100644 --- a/apiserver/pkg/chat/chat_type.go +++ b/apiserver/pkg/chat/chat_type.go @@ -20,6 +20,8 @@ import ( "time" "github.com/tmc/langchaingo/memory" + + "github.com/kubeagi/arcadia/pkg/application/retriever" ) type ResponseMode string @@ -30,33 +32,51 @@ const ( // todo isFlowValidForStream only some node(llm chain) support streaming ) +type APPMetadata struct { + APPName 
string `json:"app_name" binding:"required"` + AppNamespace string `json:"app_namespace" binding:"required"` +} + +type ConversionReqBody struct { + APPMetadata `json:",inline"` + ConversionID string `json:"conversion_id"` +} + +type MessageReqBody struct { + ConversionReqBody `json:",inline"` + MessageID string `json:"message_id"` +} + type ChatReqBody struct { - Query string `json:"query" binding:"required"` - ResponseMode ResponseMode `json:"response_mode" binding:"required"` - ConversionID string `json:"conversion_id"` - APPName string `json:"app_name" binding:"required"` - AppNamespace string `json:"app_namespace" binding:"required"` + Query string `json:"query" binding:"required"` + ResponseMode ResponseMode `json:"response_mode" binding:"required"` + ConversionReqBody `json:",inline"` + Debug bool `json:"-"` } type ChatRespBody struct { - ConversionID string `json:"conversion_id"` - MessageID string `json:"message_id"` - Message string `json:"message"` - CreatedAt time.Time `json:"created_at"` + ConversionID string `json:"conversion_id"` + MessageID string `json:"message_id"` + Message string `json:"message"` + CreatedAt time.Time `json:"created_at"` + References []retriever.Reference `json:"references,omitempty"` } type Conversion struct { - ID string `json:"id"` - AppName string `json:"app_name"` - AppNamespce string `json:"app_namespace"` - StartedAt time.Time `json:"started_at"` - UpdatedAt time.Time `json:"updated_at"` - Messages []Message `json:"messages"` - History *memory.ChatMessageHistory + ID string `json:"id"` + AppName string `json:"app_name"` + AppNamespce string `json:"app_namespace"` + StartedAt time.Time `json:"started_at"` + UpdatedAt time.Time `json:"updated_at"` + Messages []Message `json:"messages"` + History *memory.ChatMessageHistory `json:"-"` + User string `json:"-"` + Debug bool `json:"-"` } type Message struct { - ID string `json:"id"` - Query string `json:"query"` - Answer string `json:"answer"` + ID string `json:"id"` + Query 
string `json:"query"` + Answer string `json:"answer"` + References []retriever.Reference `json:"references,omitempty"` } diff --git a/apiserver/service/chat.go b/apiserver/service/chat.go index 928f3e1bf..ffa983bf4 100644 --- a/apiserver/service/chat.go +++ b/apiserver/service/chat.go @@ -13,18 +13,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package service import ( "io" "net/http" + "strings" "time" "github.com/gin-gonic/gin" "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/config" + "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/chat" + "github.com/kubeagi/arcadia/apiserver/pkg/oidc" ) func chatHandler() gin.HandlerFunc { @@ -34,11 +38,13 @@ func chatHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + req.Debug = c.Query("debug") == "true" stream := req.ResponseMode == chat.Streaming var response *chat.ChatRespBody var err error if stream { + buf := strings.Builder{} // handle chat streaming mode respStream := make(chan string, 1) go func() { @@ -48,6 +54,9 @@ func chatHandler() gin.HandlerFunc { } }() response, err = chat.AppRun(c, req, respStream) + if response != nil && response.Message == buf.String() { /* response is nil when AppRun returns an error */ + close(respStream) + } }() // Use a ticker to check if there is no data arrived and close the stream @@ -84,6 +93,7 @@ func chatHandler() gin.HandlerFunc { CreatedAt: time.Now(), }) hasData = true + buf.WriteString(msg) return true } return false @@ -101,11 +111,91 @@ func chatHandler() gin.HandlerFunc { return } c.JSON(http.StatusOK, response) + } + } +} + +func listConversationHandler() gin.HandlerFunc { + return func(c *gin.Context) { + req := chat.APPMetadata{} + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + resp, err := chat.ListConversations(c, req) + if err != nil { 
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + klog.Infof("error resp: %v", err) + return + } + c.JSON(http.StatusOK, resp) + } +} + +func deleteConversationHandler() gin.HandlerFunc { + return func(c *gin.Context) { + conversionID := c.Param("conversionID") + if conversionID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversionID is required"}) + return + } + err := chat.DeleteConversation(c, conversionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + klog.Infof("error resp: %v", err) + return + } + c.JSON(http.StatusOK, gin.H{"message": "ok"}) + } +} + +func historyHandler() gin.HandlerFunc { + return func(c *gin.Context) { + req := chat.ConversionReqBody{} + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + resp, err := chat.ListMessages(c, req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + klog.Infof("error resp: %v", err) + return + } + c.JSON(http.StatusOK, resp) + } +} + +func referenceHandler() gin.HandlerFunc { + return func(c *gin.Context) { + messageID := c.Param("messageID") + if messageID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "messageID is required"}) return } + req := chat.MessageReqBody{ + MessageID: messageID, + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + resp, err := chat.GetMessageReferences(c, req) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + klog.Infof("error resp: %v", err) + return + } + c.JSON(http.StatusOK, resp) } } -func RegisterChat(g *gin.Engine, conf config.ServerConfig) { - g.POST("/chat", chatHandler()) +func RegisterChat(g *gin.RouterGroup, conf config.ServerConfig) { + g.POST("", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), chatHandler()) // chat with bot + + 
g.POST("/conversations", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), listConversationHandler()) // list conversations + g.DELETE("/conversations/:conversionID", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), deleteConversationHandler()) // delete conversation + + g.POST("/messages", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), historyHandler()) // messages history + g.POST("/messages/:messageID/references", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), referenceHandler()) // messages reference } diff --git a/apiserver/service/router.go b/apiserver/service/router.go index 88592bc04..be0d6f8ee 100644 --- a/apiserver/service/router.go +++ b/apiserver/service/router.go @@ -51,9 +51,10 @@ func NewServerAndRun(conf config.ServerConfig) { } bffGroup := r.Group("/bff") + chatGroup := r.Group("/chat") RegisterMinIOAPI(bffGroup, conf) RegisterGraphQL(r, bffGroup, conf) - RegisterChat(r, conf) + RegisterChat(chatGroup, conf) _ = r.Run(fmt.Sprintf("%s:%d", conf.Host, conf.Port)) } diff --git a/config/samples/arcadia_v1alpha1_knowledgebase.yaml b/config/samples/arcadia_v1alpha1_knowledgebase.yaml index d73bc5dee..c0db4915e 100644 --- a/config/samples/arcadia_v1alpha1_knowledgebase.yaml +++ b/config/samples/arcadia_v1alpha1_knowledgebase.yaml @@ -12,7 +12,7 @@ spec: namespace: arcadia vectorStore: kind: VectorStores - name: chroma-sample + name: arcadia-vectorstore namespace: arcadia fileGroups: - source: diff --git a/pkg/application/app_run.go b/pkg/application/app_run.go index fadf92d78..bff008f3a 100644 --- a/pkg/application/app_run.go +++ b/pkg/application/app_run.go @@ -42,7 +42,8 @@ type Input struct { History langchaingoschema.ChatMessageHistory } type Output struct { - Answer string + Answer string + References []retriever.Reference } type Application struct { @@ -177,6 +178,11 @@ func (a *Application) Run(ctx context.Context, cli 
dynamic.Interface, respStream output = Output{Answer: answer} } } + if a, ok := out["_references"]; ok { + if references, ok := a.([]retriever.Reference); ok && len(references) > 0 { + output.References = references + } + } if output.Answer == "" && respStream == nil { return Output{}, errors.New("no answer") } diff --git a/pkg/application/chain/llmchain.go b/pkg/application/chain/llmchain.go index 748ac92af..e150b203f 100644 --- a/pkg/application/chain/llmchain.go +++ b/pkg/application/chain/llmchain.go @@ -29,6 +29,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" "github.com/kubeagi/arcadia/pkg/application/base" @@ -99,6 +100,7 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri out, err = chains.Predict(ctx, l.LLMChain, args) } } + klog.V(5).Infof("blocking out: %s", out) if err == nil { args["_answer"] = out } diff --git a/pkg/application/chain/retrievalqachain.go b/pkg/application/chain/retrievalqachain.go index af0f860eb..7fc696216 100644 --- a/pkg/application/chain/retrievalqachain.go +++ b/pkg/application/chain/retrievalqachain.go @@ -97,8 +97,10 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args llmChain := chains.NewLLMChain(llm, prompt) var baseChain chains.Chain + var stuffDocuments *appretriever.KnowledgeBaseStuffDocuments if knowledgeBaseRetriever, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok { - baseChain = appretriever.NewStuffDocuments(llmChain, knowledgeBaseRetriever.DocNullReturn) + stuffDocuments = appretriever.NewStuffDocuments(llmChain, knowledgeBaseRetriever.DocNullReturn) + baseChain = stuffDocuments } else { baseChain = chains.NewStuffDocuments(llmChain) } @@ -116,7 +118,10 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args) } } - 
klog.Infof("out:%v, err:%s", out, err) + if stuffDocuments != nil && len(stuffDocuments.References) > 0 { + args["_references"] = stuffDocuments.References + } + klog.V(5).Infof("blocking out: %s", out) if err == nil { args["_answer"] = out } diff --git a/pkg/application/retriever/knowledgebaseretriever.go b/pkg/application/retriever/knowledgebaseretriever.go index 6241ee4fe..83214fc31 100644 --- a/pkg/application/retriever/knowledgebaseretriever.go +++ b/pkg/application/retriever/knowledgebaseretriever.go @@ -19,6 +19,8 @@ package retriever import ( "context" "fmt" + "strconv" + "strings" "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/chains" @@ -37,6 +39,14 @@ import ( "github.com/kubeagi/arcadia/pkg/langchainwrap" ) +type Reference struct { + Question string `json:"question"` + Answer string `json:"answer"` + Score float32 `json:"score"` + FilePath string `json:"file_path"` + LineNumber int `json:"line_number"` +} + type KnowledgeBaseRetriever struct { langchaingoschema.Retriever base.BaseNode @@ -130,16 +140,17 @@ type KnowledgeBaseStuffDocuments struct { isDocNullReturn bool DocNullReturn string callbacks.SimpleHandler + References []Reference } -var _ chains.Chain = KnowledgeBaseStuffDocuments{} -var _ callbacks.Handler = KnowledgeBaseStuffDocuments{} +var _ chains.Chain = &KnowledgeBaseStuffDocuments{} +var _ callbacks.Handler = &KnowledgeBaseStuffDocuments{} func (c *KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Document) string { var text string docLen := len(docs) for k, doc := range docs { - klog.Infof("KnowledgeBaseRetriever: related doc[%d] raw text: %s, raw score: %v\n", k, doc.PageContent, doc.Score) + klog.Infof("KnowledgeBaseRetriever: related doc[%d] raw text: %s, raw score: %f\n", k, doc.PageContent, doc.Score) for key, v := range doc.Metadata { if str, ok := v.([]byte); ok { klog.Infof("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %s\n", k, key, string(str)) @@ -147,15 +158,24 @@ func (c 
*KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Doc klog.Infof("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %#v\n", k, key, v) } } - answer := doc.Metadata["a"] - answerBytes, _ := answer.([]byte) + answer, _ := doc.Metadata["a"].([]byte) text += doc.PageContent - if len(answerBytes) != 0 { - text = text + "\na: " + string(answerBytes) + if len(answer) != 0 { + text = text + "\na: " + strings.TrimPrefix(strings.TrimSuffix(string(answer), "\""), "\"") } if k != docLen-1 { text += c.Separator } + filepath, _ := doc.Metadata["fileName"].([]byte) + lineNumber, _ := doc.Metadata["lineNumber"].([]byte) + line, _ := strconv.Atoi(string(lineNumber)) + c.References = append(c.References, Reference{ + Question: doc.PageContent, + Answer: strings.TrimPrefix(strings.TrimSuffix(string(answer), "\""), "\""), + Score: doc.Score, + FilePath: strings.TrimPrefix(strings.TrimSuffix(string(filepath), "\""), "\""), + LineNumber: line, + }) } klog.Infof("KnowledgeBaseRetriever: finally get related text: %s\n", text) if len(text) == 0 { @@ -164,14 +184,15 @@ func (c *KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Doc return text } -func NewStuffDocuments(llmChain *chains.LLMChain, docNullReturn string) KnowledgeBaseStuffDocuments { - return KnowledgeBaseStuffDocuments{ +func NewStuffDocuments(llmChain *chains.LLMChain, docNullReturn string) *KnowledgeBaseStuffDocuments { + return &KnowledgeBaseStuffDocuments{ StuffDocuments: chains.NewStuffDocuments(llmChain), DocNullReturn: docNullReturn, + References: make([]Reference, 0, 5), } } -func (c KnowledgeBaseStuffDocuments) Call(ctx context.Context, values map[string]any, options ...chains.ChainCallOption) (map[string]any, error) { +func (c *KnowledgeBaseStuffDocuments) Call(ctx context.Context, values map[string]any, options ...chains.ChainCallOption) (map[string]any, error) { docs, ok := values[c.InputKey].([]langchaingoschema.Document) if !ok { return nil, fmt.Errorf("%w: %w", 
chains.ErrInvalidInputValues, chains.ErrInputValuesWrongType) diff --git a/tests/example-test.sh b/tests/example-test.sh index 5a1535056..3b5a76305 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -182,12 +182,14 @@ function getRespInAppChat() { data=$(jq -n --arg appname "$appname" --arg query "$query" --arg namespace "$namespace" --arg conversionID "$conversionID" '{"query":$query,"response_mode":"blocking","conversion_id":$conversionID,"app_name":$appname, "app_namespace":$namespace}') resp=$(curl -s -XPOST http://127.0.0.1:8081/chat --data "$data") ai_data=$(echo $resp | jq -r '.message') + references=$(echo $resp | jq -r '.references') if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then echo $resp exit 1 fi echo "👤: ${query}" echo "🤖: ${ai_data}" + echo "🔗: ${references}" resp_conversion_id=$(echo $resp | jq -r '.conversion_id') if [ $testStream == "true" ]; then @@ -216,7 +218,7 @@ kind load docker-image controller:example-e2e --name=$KindName info "3. install arcadia" kubectl create namespace arcadia helm install -narcadia arcadia deploy/charts/arcadia -f tests/deploy-values.yaml \ - --set controller.image=controller:example-e2e --set apiserver.image=controller:example-e2e \ + --set controller.image=controller:example-e2e --set apiserver.image=controller:example-e2e \ --wait --timeout $HelmTimeout info "4. check system datasource arcadia-minio(system datasource)" @@ -231,18 +233,8 @@ if [[ $datasourceType != "oss" ]]; then exit 1 fi -info "6. create and verify vectorstore" -info "6.1. 
helm install chroma" -helm repo add chroma https://amikos-tech.github.io/chromadb-chart/ -helm repo update chroma -if [[ $GITHUB_ACTIONS == "true" ]]; then - helm install -narcadia chroma chroma/chromadb --set service.type=ClusterIP --set chromadb.auth.enabled=false --wait --timeout $HelmTimeout -else - helm install -narcadia chroma chroma/chromadb --set service.type=ClusterIP --set chromadb.auth.enabled=false --wait --timeout $HelmTimeout --set image.repository=docker.io/abirdcfly/chroma -fi -info "6.2. verify chroma vectorstore status" -kubectl apply -f config/samples/arcadia_v1alpha1_vectorstore.yaml -waitCRDStatusReady "VectorStore" "arcadia" "chroma-sample" +info "6. verify default vectorstore" +waitCRDStatusReady "VectorStore" "arcadia" "arcadia-vectorstore" info "7. create and verify knowledgebase" @@ -273,7 +265,7 @@ kubectl apply -f config/samples/arcadia_v1alpha1_knowledgebase.yaml waitCRDStatusReady "KnowledgeBase" "arcadia" "knowledgebase-sample" info "7.5 check this vectorstore has data" -kubectl port-forward -n arcadia svc/chroma-chromadb 8000:8000 >/dev/null 2>&1 & +kubectl port-forward -n arcadia svc/arcadia-chromadb 8000:8000 >/dev/null 2>&1 & chroma_pid=$! 
info "port-forward chroma in pid: $chroma_pid" sleep 3 @@ -315,6 +307,11 @@ if [[ $resp != *"Jim"* ]]; then echo "Because conversionWindowSize is enabled to be 2, llm should record history, but resp:"$resp "dont contains Jim" exit 1 fi + +info "8.4 check conversion list and message history" +curl -XPOST http://127.0.0.1:8081/chat/conversations --data '{"app_name": "base-chat-with-bot", "app_namespace": "arcadia"}' +data=$(jq -n --arg conversionID "$resp_conversion_id" '{"conversion_id":$conversionID, "app_name": "base-chat-with-bot", "app_namespace": "arcadia"}') +curl -XPOST http://127.0.0.1:8081/chat/messages --data "$data" # There is uncertainty in the AI replies, most of the time, it will pass the test, a small percentage of the time, the AI will call names in each reply, causing the test to fail, therefore, temporarily disable the following tests #getRespInAppChat "base-chat-with-bot" "arcadia" "What is your model?" ${resp_conversion_id} "false" #getRespInAppChat "base-chat-with-bot" "arcadia" "Does your model based on gpt-3.5?" ${resp_conversion_id} "false"