From e19cd7b30c6147fd3f9fd38a493f0f7da7853c42 Mon Sep 17 00:00:00 2001 From: bjwswang Date: Fri, 12 Jan 2024 10:02:54 +0000 Subject: [PATCH] feat: generate evaluation test dataset Signed-off-by: bjwswang --- apiserver/pkg/chat/chat.go | 10 +- apiserver/pkg/chat/chat_type.go | 2 +- ...dia_v1alpha1_worker_bge-large-zh-v1.5.yaml | 6 +- .../app_run.go => appruntime/app_runtime.go} | 16 +-- .../base/context.go | 0 pkg/{application => appruntime}/base/input.go | 0 pkg/{application => appruntime}/base/node.go | 0 .../base/output.go | 0 .../chain/common.go | 0 .../chain/llmchain.go | 2 +- .../chain/retrievalqachain.go | 4 +- .../knowledgebase/knowledgebase.go | 2 +- pkg/{application => appruntime}/llm/llm.go | 2 +- .../prompt/prompt.go | 2 +- .../retriever/knowledgebaseretriever.go | 11 +- pkg/arctl/eval.go | 109 +++++++++++++++++ pkg/evaluation/evaluation.go | 114 +++++++++++++++++- pkg/evaluation/output.go | 30 +++++ 18 files changed, 285 insertions(+), 25 deletions(-) rename pkg/{application/app_run.go => appruntime/app_runtime.go} (93%) rename pkg/{application => appruntime}/base/context.go (100%) rename pkg/{application => appruntime}/base/input.go (100%) rename pkg/{application => appruntime}/base/node.go (100%) rename pkg/{application => appruntime}/base/output.go (100%) rename pkg/{application => appruntime}/chain/common.go (100%) rename pkg/{application => appruntime}/chain/llmchain.go (98%) rename pkg/{application => appruntime}/chain/retrievalqachain.go (97%) rename pkg/{application => appruntime}/knowledgebase/knowledgebase.go (94%) rename pkg/{application => appruntime}/llm/llm.go (97%) rename pkg/{application => appruntime}/prompt/prompt.go (97%) rename pkg/{application => appruntime}/retriever/knowledgebaseretriever.go (97%) create mode 100644 pkg/arctl/eval.go create mode 100644 pkg/evaluation/output.go diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat.go index 8771572e9..68112a649 100644 --- a/apiserver/pkg/chat/chat.go +++ b/apiserver/pkg/chat/chat.go @@ -33,9 +33,9 @@ import ( "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/client" - "github.com/kubeagi/arcadia/pkg/application" - "github.com/kubeagi/arcadia/pkg/application/base" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) var ( @@ -99,12 +99,12 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat Answer: "", }) ctx = base.SetAppNamespace(ctx, req.AppNamespace) - appRun, err := application.NewAppOrGetFromCache(ctx, app, c) + appRun, err := appruntime.NewAppOrGetFromCache(ctx, c, app) if err != nil { return nil, err } klog.FromContext(ctx).Info("begin to run application", "appName", req.APPName, "appNamespace", req.AppNamespace) - out, err := appRun.Run(ctx, c, respStream, application.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: conversation.History}) + out, err := appRun.Run(ctx, c, respStream, appruntime.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: conversation.History}) if err != nil { return nil, err } diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/chat_type.go index a710e6cab..348c1ee73 100644 --- a/apiserver/pkg/chat/chat_type.go +++ b/apiserver/pkg/chat/chat_type.go @@ -21,7 +21,7 @@ import ( "github.com/tmc/langchaingo/memory" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type ResponseMode string diff --git a/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml b/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml index 8cf01d91e..845da0f13 100644 --- a/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml +++ b/config/samples/arcadia_v1alpha1_worker_bge-large-zh-v1.5.yaml @@ -1,8 +1,8 @@ -\apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Worker metadata: name: bge-large-zh - namespace: arcadia + namespace: kubeagi-system spec: displayName: BGE模型服务 description: "这是一个Embedding模型服务,由BGE提供" @@ -10,4 +10,4 @@ spec: replicas: 1 model: kind: "Models" - name: "bge-large-zh-v1.5" + name: "bge-large-zh-v1.5" \ No newline at end of file diff --git a/pkg/application/app_run.go b/pkg/appruntime/app_runtime.go similarity index 93% rename from pkg/application/app_run.go rename to pkg/appruntime/app_runtime.go index 4189aee3a..f6e0740ce 100644 --- a/pkg/application/app_run.go +++ b/pkg/appruntime/app_runtime.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package application +package appruntime import ( "container/list" @@ -28,12 +28,12 @@ import ( "k8s.io/utils/strings/slices" arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" - "github.com/kubeagi/arcadia/pkg/application/chain" - "github.com/kubeagi/arcadia/pkg/application/knowledgebase" - "github.com/kubeagi/arcadia/pkg/application/llm" - "github.com/kubeagi/arcadia/pkg/application/prompt" - "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/chain" + "github.com/kubeagi/arcadia/pkg/appruntime/knowledgebase" + "github.com/kubeagi/arcadia/pkg/appruntime/llm" + "github.com/kubeagi/arcadia/pkg/appruntime/prompt" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type Input struct { @@ -62,7 +62,7 @@ type Application struct { // return app.Namespace + "/" + app.Name //} -func NewAppOrGetFromCache(ctx context.Context, app *arcadiav1alpha1.Application, cli dynamic.Interface) (*Application, error) { +func NewAppOrGetFromCache(ctx context.Context, cli dynamic.Interface, app *arcadiav1alpha1.Application) (*Application, error) { if app == nil || app.Name == "" || app.Namespace == "" { return nil, errors.New("app has no name or namespace") } diff --git a/pkg/application/base/context.go b/pkg/appruntime/base/context.go similarity index 100% rename from pkg/application/base/context.go rename to pkg/appruntime/base/context.go diff --git a/pkg/application/base/input.go b/pkg/appruntime/base/input.go similarity index 100% rename from pkg/application/base/input.go rename to pkg/appruntime/base/input.go diff --git a/pkg/application/base/node.go b/pkg/appruntime/base/node.go similarity index 100% rename from pkg/application/base/node.go rename to pkg/appruntime/base/node.go diff --git a/pkg/application/base/output.go b/pkg/appruntime/base/output.go similarity index 100% rename from pkg/application/base/output.go rename to pkg/appruntime/base/output.go diff --git a/pkg/application/chain/common.go b/pkg/appruntime/chain/common.go similarity index 100% rename from pkg/application/chain/common.go rename to pkg/appruntime/chain/common.go diff --git a/pkg/application/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go similarity index 98% rename from pkg/application/chain/llmchain.go rename to pkg/appruntime/chain/llmchain.go index 3ece671e6..9de8d57bd 100644 --- a/pkg/application/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -32,7 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type LLMChain struct { diff --git a/pkg/application/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go similarity index 97% rename from pkg/application/chain/retrievalqachain.go rename to pkg/appruntime/chain/retrievalqachain.go index f94dd3ad5..d846f1954 100644 --- a/pkg/application/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -32,8 +32,8 @@ import ( "k8s.io/klog/v2" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" - appretriever "github.com/kubeagi/arcadia/pkg/application/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + appretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type RetrievalQAChain struct { diff --git a/pkg/application/knowledgebase/knowledgebase.go b/pkg/appruntime/knowledgebase/knowledgebase.go similarity index 94% rename from pkg/application/knowledgebase/knowledgebase.go rename to pkg/appruntime/knowledgebase/knowledgebase.go index 8074e337d..38d8290d0 100644 --- a/pkg/application/knowledgebase/knowledgebase.go +++ b/pkg/appruntime/knowledgebase/knowledgebase.go @@ -21,7 +21,7 @@ import ( "k8s.io/client-go/dynamic" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type Knowledgebase struct { diff --git a/pkg/application/llm/llm.go b/pkg/appruntime/llm/llm.go similarity index 97% rename from pkg/application/llm/llm.go rename to pkg/appruntime/llm/llm.go index 3846658b9..c6ce4f158 100644 --- a/pkg/application/llm/llm.go +++ b/pkg/appruntime/llm/llm.go @@ -27,7 +27,7 @@ import ( "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/langchainwrap" ) diff --git a/pkg/application/prompt/prompt.go b/pkg/appruntime/prompt/prompt.go similarity index 97% rename from pkg/application/prompt/prompt.go rename to pkg/appruntime/prompt/prompt.go index a5644ed6b..11a9b9c3a 100644 --- a/pkg/application/prompt/prompt.go +++ b/pkg/appruntime/prompt/prompt.go @@ -27,7 +27,7 @@ import ( "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/app-node/prompt/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) type Prompt struct { diff --git a/pkg/application/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go similarity index 97% rename from pkg/application/retriever/knowledgebaseretriever.go rename to pkg/appruntime/retriever/knowledgebaseretriever.go index 0bc3b1a43..da244e32d 100644 --- a/pkg/application/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,6 +18,7 @@ package retriever import ( "context" + "encoding/json" "fmt" "strconv" "strings" @@ -34,7 +35,7 @@ import ( apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/pkg/application/base" + "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/langchainwrap" pkgvectorstore "github.com/kubeagi/arcadia/pkg/vectorstore" ) @@ -52,6 +53,14 @@ type Reference struct { LineNumber int `json:"line_number" example:"7"` } +func (reference Reference) String() string { + bytes, err := json.Marshal(&reference) + if err != nil { + return "" + } + return string(bytes) +} + type KnowledgeBaseRetriever struct { langchaingoschema.Retriever base.BaseNode diff --git a/pkg/arctl/eval.go b/pkg/arctl/eval.go new file mode 100644 index 000000000..27cf5f2e7 --- /dev/null +++ b/pkg/arctl/eval.go @@ -0,0 +1,109 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package arctl + +import ( + "bytes" + "context" + "io" + "os" + + basev1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/apiserver/graph/generated" + "github.com/kubeagi/arcadia/apiserver/pkg/common" + "github.com/kubeagi/arcadia/pkg/evaluation" + "github.com/spf13/cobra" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/dynamic" +) + +func NewEvalCmd(kubeClient dynamic.Interface, namespace string) *cobra.Command { + var app string + + cmd := &cobra.Command{ + Use: "eval", + Short: "Manage evaluations", + } + + cmd.Flags().StringVar(&app, "app", "", "The application that is going to be evaluated") + + cmd.AddCommand(EvalGenTestDataset(kubeClient, namespace, app)) + + return cmd +} + +func EvalGenTestDataset(kubeClient dynamic.Interface, namespace string, appName string) *cobra.Command { + var file string + var questionColumn string + var groundTruthsColumn string + + cmd := &cobra.Command{ + Use: "gen_test_dataset", + Short: "Generate a test dataset for evaluation", + RunE: func(cmd *cobra.Command, args []string) error { + // read file content + f, err := os.Open(file) + if err != nil { + return err + } + data, err := io.ReadAll(f) + if err != nil { + return err + } + + // read files + app := &basev1alpha1.Application{} + obj, err := common.ResouceGet(context.Background(), kubeClient, generated.TypedObjectReferenceInput{ + APIGroup: &common.ArcadiaAPIGroup, + Kind: "Application", + Namespace: &namespace, + Name: appName, + }, v1.GetOptions{}) + if err != nil { + return err + } + + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), app) + if err != nil { + return err + } + + generator, err := evaluation.NewRagasDatasetGenerator(context.Background(), kubeClient, app) + if err != nil { + return err + } + + _, err = generator.Generate( + context.Background(), + bytes.NewReader(data), + evaluation.WithQuestionColumn(questionColumn), + evaluation.WithGroundTruthsColumn(groundTruthsColumn), + ) + if err != nil { + return err + } + return nil + }, + } + + cmd.Flags().StringVar(&file, "file", "", "The file(CSV) which provides question and ground_truths") + cmd.Flags().StringVar(&questionColumn, "question-column", "q", "The column name which provides questions") + cmd.Flags().StringVar(&groundTruthsColumn, "ground-truths-column", "a", "The column name which provides the answers") + + return cmd +} diff --git a/pkg/evaluation/evaluation.go b/pkg/evaluation/evaluation.go index f209b535f..ee327562d 100644 --- a/pkg/evaluation/evaluation.go +++ b/pkg/evaluation/evaluation.go @@ -16,4 +16,116 @@ limitations under the License. package evaluation -// TO BE DEFINED +import ( + "context" + "io" + + "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders" + "k8s.io/client-go/dynamic" +) + +type RagasDataRow struct { + Question string `json:"question"` + GroundTruths string `json:"ground_truths"` + Contexts []string `json:"contexts"` + Answer string `json:"answer"` +} + +type RagasDatasetGenerator struct { + cli dynamic.Interface + + app appruntime.Application +} + +func NewRagasDatasetGenerator(ctx context.Context, cli dynamic.Interface, app *v1alpha1.Application) (*RagasDatasetGenerator, error) { + runapp, err := appruntime.NewAppOrGetFromCache(ctx, cli, app) + if err != nil { + return nil, err + } + return &RagasDatasetGenerator{cli: cli, app: *runapp}, nil +} + +type genOptions struct { + // questionColumn in csv file + questionColumn string + groundTruthsColumn string + + output Output +} + +func defaultGenOptions() *genOptions { + return &genOptions{ + questionColumn: "q", + groundTruthsColumn: "a", + output: &PrintOutput{}, + } +} + +func WithQuestionColumn(questionColumn string) GenOptions { + return func(genOpts *genOptions) { + genOpts.questionColumn = questionColumn + } +} + +func WithGroundTruthsColumn(groundTruthsColumn string) GenOptions { + return func(genOpts *genOptions) { + genOpts.groundTruthsColumn = groundTruthsColumn + } +} + +func WithOutput(output Output) GenOptions { + return func(genOpts *genOptions) { + genOpts.output = output + } +} + +type GenOptions func(*genOptions) + +// Generate a test dataset from a file(csv) +func (eval *RagasDatasetGenerator) Generate(ctx context.Context, csvData io.Reader, genOptions ...GenOptions) ([]RagasDataRow, error) { + ctx = base.SetAppNamespace(ctx, eval.app.Namespace) + + // set generation options + genOpts := defaultGenOptions() + for _, o := range genOptions { + o(genOpts) + } + + // load csv to langchain documents + loader := pkgdocumentloaders.NewQACSV(csvData, "", genOpts.questionColumn, genOpts.groundTruthsColumn) + langchainDocuments, err := loader.Load(ctx) + if err != nil { + return nil, err + } + + // convert langchain documents to ragas dataset + var ragasRows = make([]RagasDataRow, len(langchainDocuments)) + for docIndex, doc := range langchainDocuments { + ragasRow := RagasDataRow{ + Question: doc.PageContent, + GroundTruths: doc.Metadata[genOpts.groundTruthsColumn].(string), + } + + // chat with application + out, err := eval.app.Run(ctx, eval.cli, nil, appruntime.Input{Question: ragasRow.Question, NeedStream: false, History: nil}) + if err != nil { + return nil, err + } + ragasRow.Answer = out.Answer + + // handle context + contexts := make([]string, len(out.References)) + for refIndex, reference := range out.References { + contexts[refIndex] = reference.String() + } + ragasRow.Contexts = contexts + + // set ragasRows + ragasRows[docIndex] = ragasRow + } + + return ragasRows, nil +} diff --git a/pkg/evaluation/output.go b/pkg/evaluation/output.go new file mode 100644 index 000000000..26a6fb1fe --- /dev/null +++ b/pkg/evaluation/output.go @@ -0,0 +1,30 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evaluation + +import "fmt" + +type Output interface { + Output(RagasDataRow) error +} + +type PrintOutput struct{} + +func (print *PrintOutput) Output(row RagasDataRow) error { + fmt.Printf("question:%s \n ground_truths:%s \n answer:%s \n contexts:%v \n", row.Question, row.GroundTruths, row.Answer, row.Contexts[:]) + return nil +}