diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 3c28a9fb4..a0499e806 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -1280,6 +1280,70 @@ const docTemplate = `{ } } } + }, + "/rags/scatter": { + "get": { + "description": "Get scatter data of a rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get scatter data of a rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.ReportDetail" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { @@ -1567,12 +1631,6 @@ const docTemplate = `{ "$ref": "#/definitions/rag.RadarData" } }, - "scatterChart": { - "type": "array", - "items": { - "$ref": "#/definitions/rag.ScatterData" - } - }, "summary": { "description": "TODO", "type": "string" @@ -1631,20 +1689,6 @@ const docTemplate = `{ } } }, - "rag.ScatterData": { - "type": "object", - "properties": { - "color": { - "type": "string" - }, - "score": { - "type": "number" - }, - "type": { - "type": "string" - } - } - }, "rag.TotalScoreData": { "type": "object", "properties": { diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index a0ceab04c..d1d3340a2 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -1274,6 +1274,70 @@ } } } + }, + "/rags/scatter": { + "get": { + "description": "Get scatter data of a rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get scatter data of a rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.ReportDetail" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { @@ -1561,12 +1625,6 @@ "$ref": "#/definitions/rag.RadarData" } }, - "scatterChart": { - "type": "array", - "items": { - "$ref": "#/definitions/rag.ScatterData" - } - }, "summary": { "description": "TODO", "type": "string" @@ -1625,20 +1683,6 @@ } } }, - "rag.ScatterData": { - "type": "object", - "properties": { - "color": { - "type": "string" - }, - "score": { - "type": "number" - }, - "type": { - "type": "string" - } - } - }, "rag.TotalScoreData": { "type": "object", "properties": { diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index aae1ff929..7c9e015ad 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -207,10 +207,6 @@ definitions: items: $ref: '#/definitions/rag.RadarData' type: array - scatterChart: - items: - $ref: '#/definitions/rag.ScatterData' - type: array summary: description: TODO type: string @@ -249,15 +245,6 @@ definitions: totalScore: type: number type: object - rag.ScatterData: - properties: - color: - type: string - score: - type: number - type: - type: string - type: object rag.TotalScoreData: properties: color: @@ -1348,6 +1335,49 @@ paths: summary: Get a summary of rag tags: - RAG + /rags/scatter: + get: + consumes: + - application/json + description: Get scatter data of a rag + parameters: + - description: rag name + in: query + name: ragName + required: true + type: string + - description: Name of the bucket + in: header + name: namespace + required: true + type: string + - description: application name + in: query + name: appName + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/rag.ReportDetail' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Get scatter data of a rag + tags: + - RAG securityDefinitions: ApiKeyAuth: description: API token for authorization diff --git a/apiserver/pkg/rag/report.go b/apiserver/pkg/rag/report.go index a370282bc..1cc9a6399 100644 --- a/apiserver/pkg/rag/report.go +++ b/apiserver/pkg/rag/report.go @@ -20,6 +20,7 @@ import ( "encoding/csv" "fmt" "io" + "math" "sort" "strconv" "strings" @@ -35,9 +36,16 @@ import ( const ( totalScore = "total_score" + csvLatencyField = "latency" + csvQuestionField = "question" + csvGroundTruthsField = "ground_truths" + csvAnswerField = "answer" + csvContextsField = "contexts" + + epsilon = 1e-6 // TODO: support for color change via env - blueColorEnv = "BLUE_ENV" - blue = "blue" // 散点图的颜色 + // blueColorEnv = "BLUE_ENV" + // blue = "blue" // 散点图的颜色 orangeEnv = "ORANGE_RNV" orange = "orange" // 差 @@ -52,6 +60,15 @@ const ( ) var ( + // NOTE: if other fields are added in the Generate Test Data section, they need to be updated here as well. + csvBasicFields = map[string]struct{}{ + csvQuestionField: {}, + csvGroundTruthsField: {}, + csvAnswerField: {}, + csvContextsField: {}, + csvLatencyField: {}, + } + metricChinese = map[string]string{ string(v1alpha1.AnswerRelevancy): "答案相关度", string(v1alpha1.AnswerSimilarity): "答案相似度", @@ -70,7 +87,7 @@ var ( string(v1alpha1.Faithfulness): "调整模型配置或更换模型", string(v1alpha1.ContextPrecision): "调整 Embedding 模型", string(v1alpha1.ContextRelevancy): "调整 Embedding 模型", - string(v1alpha1.ContextRecall): "调整 QA 数据", // 知识库相似度? + string(v1alpha1.ContextRecall): "调整 QA 数据", string(v1alpha1.AspectCritique): "暂时没用到", } ) @@ -88,22 +105,20 @@ type ( } ScatterData struct { - Score float64 `json:"score"` - Type string `json:"type"` - Color string `json:"color"` + Score float64 `json:"score"` + CostTime float64 `json:"costTime"` } Report struct { - RadarChart []RadarData `json:"radarChart"` - TotalScore TotalScoreData `json:"totalScore"` - ScatterChart []ScatterData `json:"scatterChart"` + RadarChart []RadarData `json:"radarChart"` + TotalScore TotalScoreData `json:"totalScore"` // TODO Summary string `json:"summary"` } // 忠实度、答案相关度、答案语义相似度、答案正确性、知识库相关度、知识库精度、知识库相似度 - // question,ground_truths,answer,contexts + // question,ground_truths,answer,contexts,latency ReportLine struct { Question string `json:"question"` GroundTruths []string `json:"groundTruths"` @@ -137,9 +152,8 @@ func ParseSummary( return Report{}, err } csvReader := csv.NewReader(object) - report := Report{TotalScore: TotalScoreData{}, RadarChart: []RadarData{}, ScatterChart: []ScatterData{}} + report := Report{TotalScore: TotalScoreData{}, RadarChart: []RadarData{}} radarChecker := make(map[string]int) - scatterChecker := make(map[string]int) changeTotalScoreColor := false @@ -187,15 +201,6 @@ func ParseSummary( metricSuggesstion = append(metricSuggesstion, suggestionChinese[line[0]]) changeTotalScoreColor = true } - - nextScatterIndex := len(report.ScatterChart) - idx, ok = scatterChecker[line[0]] - if !ok { - scatterChecker[line[0]] = nextScatterIndex - report.ScatterChart = append(report.ScatterChart, ScatterData{Type: line[0], Color: blue}) - idx = nextScatterIndex - } - report.ScatterChart[idx].Score = score } if changeTotalScoreColor { @@ -207,6 +212,64 @@ func ParseSummary( return report, nil } +func PraseScatterChart(ctx context.Context, c client.Client, appName, ragName, namespace string) ([]ScatterData, error) { + source, err := common.SystemDatasourceOSS(ctx, c) + if err != nil { + klog.Errorf("failed to get system datasource error %s", err) + return nil, err + } + + filePath := fmt.Sprintf("evals/%s/%s/result.csv", appName, ragName) + object, err := source.Client.GetObject(ctx, namespace, filePath, minio.GetObjectOptions{}) + if err != nil { + klog.Errorf("failed to get result.csv file error %s", err) + return nil, err + } + csvReader := csv.NewReader(object) + data, err := csvReader.ReadAll() + if err != nil { + klog.Error("failed to read csv error %s", err) + return nil, err + } + + extra := make([]int, 0) + header := data[0] + latencyIndex := 0 + for i := 1; i < len(header); i++ { + if header[i] == csvLatencyField { + latencyIndex = i + continue + } + if _, ok := csvBasicFields[header[i]]; !ok { + extra = append(extra, i) + } + } + + result := make([]ScatterData, 0) + if len(extra) == 0 { + return result, nil + } + + for _, line := range data[1:] { + costTime, _ := strconv.ParseFloat(line[latencyIndex], 64) + sum := float64(0) + for _, index := range extra { + f, _ := strconv.ParseFloat(line[index], 64) + sum += f + } + score := sum / float64(len(extra)) + result = append(result, ScatterData{CostTime: costTime, Score: score}) + } + + sort.SliceStable(result, func(i, j int) bool { + if math.Abs(result[i].CostTime-result[j].CostTime) < epsilon { + return result[i].Score < result[j].Score + } + return result[i].CostTime < result[j].CostTime + }) + return result, nil +} + func ParseResult( ctx context.Context, c client.Client, page, pageSize int, @@ -240,26 +303,39 @@ func ParseResult( return ReportDetail{}, nil } + extra := make([]int, 0) result := make([]ReportLine, len(data)-1) + csvBasicFieldIndies := make(map[string]int) header := data[0] + for i := 1; i < len(header); i++ { + _, ok := csvBasicFields[header[i]] + if ok { + csvBasicFieldIndies[header[i]] = i + continue + } + extra = append(extra, i) + } + if len(extra) == 0 { + return ReportDetail{}, nil + } + for i, line := range data[1:] { item := ReportLine{ - Question: line[1], - GroundTruths: []string{line[2]}, - Answer: line[3], - Contexts: []string{line[4]}, + Question: line[csvBasicFieldIndies[csvQuestionField]], + GroundTruths: []string{line[csvBasicFieldIndies[csvGroundTruthsField]]}, + Answer: line[csvBasicFieldIndies[csvAnswerField]], + Contexts: []string{line[csvBasicFieldIndies[csvContextsField]]}, Data: make(map[string]float64), } - item.CostTime, _ = strconv.ParseFloat(line[5], 64) + item.CostTime, _ = strconv.ParseFloat(line[csvBasicFieldIndies[csvLatencyField]], 64) sum := float64(0) - // TODO: Avoid direct hardcode. Mapping index via map - for i := 6; i < len(line); i++ { - f, _ := strconv.ParseFloat(line[i], 64) - item.Data[header[i]] = f + for _, idx := range extra { + f, _ := strconv.ParseFloat(line[idx], 64) + item.Data[header[idx]] = f sum += f } - item.TotalScore = sum / float64(len(line)-6) + item.TotalScore = sum / float64(len(extra)) result[i] = item } diff --git a/apiserver/service/rag_server.go b/apiserver/service/rag_server.go index 5ce1b4ef9..a473ca941 100644 --- a/apiserver/service/rag_server.go +++ b/apiserver/service/rag_server.go @@ -51,14 +51,12 @@ const ( // @Produce json // @Param ragName query string true "rag name" // @Param namespace header string true "Name of the bucket" -// @Param appName query string true "application name" // @Success 200 {object} rag.Report // @Failure 400 {object} map[string]string // @Failure 500 {object} map[string]string // @Router /rags/report [get] func (r *RagAPI) Summary(ctx *gin.Context) { ragName := ctx.Query(ragNameQuery) - appName := ctx.Query(appNameQuery) namespace := ctx.GetHeader(namespaceHeadr) rr := v1alpha1.RAG{} @@ -76,7 +74,7 @@ func (r *RagAPI) Summary(ctx *gin.Context) { thresholds[string(param.Kind)] = float64(param.ToleranceThreshbold) / 100.0 } - report, err := rag.ParseSummary(ctx.Request.Context(), r.c, appName, ragName, namespace, thresholds) + report, err := rag.ParseSummary(ctx.Request.Context(), r.c, rr.Spec.Application.Name, ragName, namespace, thresholds) if err != nil { klog.Errorf("an error occurred generating the report, error %s", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ @@ -95,7 +93,6 @@ func (r *RagAPI) Summary(ctx *gin.Context) { // @Produce json // @Param ragName query string true "rag name" // @Param namespace header string true "Name of the bucket" -// @Param appName query string true "application name" // @Param page query int false "default is 1" // @Param size query string false "default is 10" // @Param sortBy query string false "rag metrcis" @@ -110,10 +107,19 @@ func (r *RagAPI) ReportDetail(ctx *gin.Context) { sortBy := ctx.Query("sortBy") order := ctx.DefaultQuery("order", "desc") ragName := ctx.Query(ragNameQuery) - appName := ctx.Query(appNameQuery) namespace := ctx.GetHeader(namespaceHeadr) - result, err := rag.ParseResult(ctx.Request.Context(), r.c, page, pageSize, appName, ragName, namespace, sortBy, order) + rr := v1alpha1.RAG{} + if err := r.c.Get(ctx, types.NamespacedName{ + Namespace: namespace, Name: ragName, + }, &rr); err != nil { + klog.Error(fmt.Sprintf("can't get rag by name %s", ragName)) + ctx.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ + "message": fmt.Sprintf("can't get rag by name %s", ragName), + }) + return + } + result, err := rag.ParseResult(ctx.Request.Context(), r.c, page, pageSize, rr.Spec.Application.Name, ragName, namespace, sortBy, order) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ "message": err.Error(), @@ -123,6 +129,44 @@ func (r *RagAPI) ReportDetail(ctx *gin.Context) { ctx.JSON(http.StatusOK, result) } +// @Summary Get scatter data of a rag +// @Schemes +// @Description Get scatter data of a rag +// @Tags RAG +// @Accept json +// @Produce json +// @Param ragName query string true "rag name" +// @Param namespace header string true "Name of the bucket" +// @Success 200 {object} rag.ReportDetail +// @Failure 400 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /rags/scatter [get] +func (r *RagAPI) ReportScatter(ctx *gin.Context) { + ragName := ctx.Query(ragNameQuery) + namespace := ctx.GetHeader(namespaceHeadr) + + rr := v1alpha1.RAG{} + if err := r.c.Get(ctx, types.NamespacedName{ + Namespace: namespace, Name: ragName, + }, &rr); err != nil { + klog.Error(fmt.Sprintf("can't get rag by name %s", ragName)) + ctx.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ + "message": fmt.Sprintf("can't get rag by name %s", ragName), + }) + return + } + result, err := rag.PraseScatterChart(ctx.Request.Context(), r.c, rr.Spec.Application.Name, ragName, namespace) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "message": err.Error(), + }) + return + } + ctx.JSON(http.StatusOK, gin.H{ + "data": result, + }) +} + func registerRAG(g *gin.RouterGroup, conf gqlconfig.ServerConfig) { cfg := ctrl.GetConfigOrDie() c, err := client.New(cfg, client.Options{Scheme: conf.Scheme}) @@ -133,4 +177,5 @@ func registerRAG(g *gin.RouterGroup, conf gqlconfig.ServerConfig) { g.GET("/report", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.Summary) g.GET("/detail", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.ReportDetail) + g.GET("/scatter", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.ReportScatter) }