diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index c5273a163..bcd1c19ad 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -130,6 +130,7 @@ type ComplexityRoot struct { Description func(childComplexity int) int Enable func(childComplexity int) int FileProgress func(childComplexity int) int + LlmConfig func(childComplexity int) int Name func(childComplexity int) int Preview func(childComplexity int) int ZhName func(childComplexity int) int @@ -373,6 +374,17 @@ type ComplexityRoot struct { UpdateTimestamp func(childComplexity int) int } + LLMConfig struct { + MaxTokens func(childComplexity int) int + Model func(childComplexity int) int + Name func(childComplexity int) int + Namespace func(childComplexity int) int + PromptTemplate func(childComplexity int) int + Provider func(childComplexity int) int + Temperature func(childComplexity int) int + TopP func(childComplexity int) int + } + LLMQuery struct { GetLlm func(childComplexity int, name string, namespace string) int ListLLMs func(childComplexity int, input ListCommonInput) int @@ -1051,6 +1063,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.DataProcessConfigChildren.FileProgress(childComplexity), true + case "DataProcessConfigChildren.llm_config": + if e.complexity.DataProcessConfigChildren.LlmConfig == nil { + break + } + + return e.complexity.DataProcessConfigChildren.LlmConfig(childComplexity), true + case "DataProcessConfigChildren.name": if e.complexity.DataProcessConfigChildren.Name == nil { break @@ -2306,6 +2325,62 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.LLM.UpdateTimestamp(childComplexity), true + case "LLMConfig.max_tokens": + if e.complexity.LLMConfig.MaxTokens == nil { + break + } + + return e.complexity.LLMConfig.MaxTokens(childComplexity), true + + case "LLMConfig.model": + if e.complexity.LLMConfig.Model == nil { + break + } + + return e.complexity.LLMConfig.Model(childComplexity), true + + case "LLMConfig.name": + if e.complexity.LLMConfig.Name == nil { + break + } + + return e.complexity.LLMConfig.Name(childComplexity), true + + case "LLMConfig.namespace": + if e.complexity.LLMConfig.Namespace == nil { + break + } + + return e.complexity.LLMConfig.Namespace(childComplexity), true + + case "LLMConfig.prompt_template": + if e.complexity.LLMConfig.PromptTemplate == nil { + break + } + + return e.complexity.LLMConfig.PromptTemplate(childComplexity), true + + case "LLMConfig.provider": + if e.complexity.LLMConfig.Provider == nil { + break + } + + return e.complexity.LLMConfig.Provider(childComplexity), true + + case "LLMConfig.temperature": + if e.complexity.LLMConfig.Temperature == nil { + break + } + + return e.complexity.LLMConfig.Temperature(childComplexity), true + + case "LLMConfig.top_p": + if e.complexity.LLMConfig.TopP == nil { + break + } + + return e.complexity.LLMConfig.TopP(childComplexity), true + case "LLMQuery.getLLM": if e.complexity.LLMQuery.GetLlm == nil { break @@ -4027,10 +4102,22 @@ type DataProcessConfigChildren { enable: String zh_name: String description: String + llm_config: LLMConfig preview: [DataProcessConfigpreView] file_progress: [DataProcessConfigpreFileProgress] } +type LLMConfig { + name: String + namespace: String + model: String + temperature: String + top_p: String + max_tokens: String + prompt_template: String + provider: String +} + # 数据处理配置项预览 type DataProcessConfigpreView { file_name: String @@ -8579,6 +8666,8 @@ func (ec *executionContext) fieldContext_DataProcessConfig_children(ctx context. return ec.fieldContext_DataProcessConfigChildren_zh_name(ctx, field) case "description": return ec.fieldContext_DataProcessConfigChildren_description(ctx, field) + case "llm_config": + return ec.fieldContext_DataProcessConfigChildren_llm_config(ctx, field) case "preview": return ec.fieldContext_DataProcessConfigChildren_preview(ctx, field) case "file_progress": @@ -8754,6 +8843,65 @@ func (ec *executionContext) fieldContext_DataProcessConfigChildren_description(c return fc, nil } +func (ec *executionContext) _DataProcessConfigChildren_llm_config(ctx context.Context, field graphql.CollectedField, obj *DataProcessConfigChildren) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_DataProcessConfigChildren_llm_config(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.LlmConfig, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*LLMConfig) + fc.Result = res + return ec.marshalOLLMConfig2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLLMConfig(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_DataProcessConfigChildren_llm_config(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "DataProcessConfigChildren", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "name": + return ec.fieldContext_LLMConfig_name(ctx, field) + case "namespace": + return ec.fieldContext_LLMConfig_namespace(ctx, field) + case "model": + return ec.fieldContext_LLMConfig_model(ctx, field) + case "temperature": + return ec.fieldContext_LLMConfig_temperature(ctx, field) + case "top_p": + return ec.fieldContext_LLMConfig_top_p(ctx, field) + case "max_tokens": + return ec.fieldContext_LLMConfig_max_tokens(ctx, field) + case "prompt_template": + return ec.fieldContext_LLMConfig_prompt_template(ctx, field) + case "provider": + return ec.fieldContext_LLMConfig_provider(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type LLMConfig", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _DataProcessConfigChildren_preview(ctx context.Context, field graphql.CollectedField, obj *DataProcessConfigChildren) (ret graphql.Marshaler) { fc, err := ec.fieldContext_DataProcessConfigChildren_preview(ctx, field) if err != nil { @@ -16410,6 +16558,334 @@ func (ec *executionContext) fieldContext_LLM_message(ctx context.Context, field return fc, nil } +func (ec *executionContext) _LLMConfig_name(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_name(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Name, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_name(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_namespace(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_namespace(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Namespace, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_namespace(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_model(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_model(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Model, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_model(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_temperature(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_temperature(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Temperature, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_temperature(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_top_p(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_top_p(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.TopP, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_top_p(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_max_tokens(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_max_tokens(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.MaxTokens, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_max_tokens(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_prompt_template(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_prompt_template(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.PromptTemplate, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_prompt_template(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LLMConfig_provider(ctx context.Context, field graphql.CollectedField, obj *LLMConfig) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LLMConfig_provider(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Provider, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LLMConfig_provider(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LLMConfig", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _LLMQuery_getLLM(ctx context.Context, field graphql.CollectedField, obj *LLMQuery) (ret graphql.Marshaler) { fc, err := ec.fieldContext_LLMQuery_getLLM(ctx, field) if err != nil { @@ -29092,6 +29568,8 @@ func (ec *executionContext) _DataProcessConfigChildren(ctx context.Context, sel out.Values[i] = ec._DataProcessConfigChildren_zh_name(ctx, field, obj) case "description": out.Values[i] = ec._DataProcessConfigChildren_description(ctx, field, obj) + case "llm_config": + out.Values[i] = ec._DataProcessConfigChildren_llm_config(ctx, field, obj) case "preview": out.Values[i] = ec._DataProcessConfigChildren_preview(ctx, field, obj) case "file_progress": @@ -31417,6 +31895,56 @@ func (ec *executionContext) _LLM(ctx context.Context, sel ast.SelectionSet, obj return out } +var lLMConfigImplementors = []string{"LLMConfig"} + +func (ec *executionContext) _LLMConfig(ctx context.Context, sel ast.SelectionSet, obj *LLMConfig) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, lLMConfigImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("LLMConfig") + case "name": + out.Values[i] = ec._LLMConfig_name(ctx, field, obj) + case "namespace": + out.Values[i] = ec._LLMConfig_namespace(ctx, field, obj) + case "model": + out.Values[i] = ec._LLMConfig_model(ctx, field, obj) + case "temperature": + out.Values[i] = ec._LLMConfig_temperature(ctx, field, obj) + case "top_p": + out.Values[i] = ec._LLMConfig_top_p(ctx, field, obj) + case "max_tokens": + out.Values[i] = ec._LLMConfig_max_tokens(ctx, field, obj) + case "prompt_template": + out.Values[i] = ec._LLMConfig_prompt_template(ctx, field, obj) + case "provider": + out.Values[i] = ec._LLMConfig_provider(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var lLMQueryImplementors = []string{"LLMQuery"} func (ec *executionContext) _LLMQuery(ctx context.Context, sel ast.SelectionSet, obj *LLMQuery) graphql.Marshaler { @@ -35282,6 +35810,13 @@ func (ec *executionContext) marshalOLLM2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋap return ec._LLM(ctx, sel, v) } +func (ec *executionContext) marshalOLLMConfig2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLLMConfig(ctx context.Context, sel ast.SelectionSet, v *LLMConfig) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._LLMConfig(ctx, sel, v) +} + func (ec *executionContext) unmarshalOLLMConfigItem2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLLMConfigItem(ctx context.Context, v interface{}) (*LLMConfigItem, error) { if v == nil { return nil, nil diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 2cc6d096b..618491c50 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -357,6 +357,7 @@ type DataProcessConfigChildren struct { Enable *string `json:"enable,omitempty"` ZhName *string `json:"zh_name,omitempty"` Description *string `json:"description,omitempty"` + LlmConfig *LLMConfig `json:"llm_config,omitempty"` Preview []*DataProcessConfigpreView `json:"preview,omitempty"` FileProgress []*DataProcessConfigpreFileProgress `json:"file_progress,omitempty"` } @@ -780,6 +781,17 @@ type Llm struct { func (Llm) IsPageNode() {} +type LLMConfig struct { + Name *string `json:"name,omitempty"` + Namespace *string `json:"namespace,omitempty"` + Model *string `json:"model,omitempty"` + Temperature *string `json:"temperature,omitempty"` + TopP *string `json:"top_p,omitempty"` + MaxTokens *string `json:"max_tokens,omitempty"` + PromptTemplate *string `json:"prompt_template,omitempty"` + Provider *string `json:"provider,omitempty"` +} + type LLMConfigItem struct { Name *string `json:"name,omitempty"` Namespace *string `json:"namespace,omitempty"` diff --git a/apiserver/graph/schema/dataprocessing.gql b/apiserver/graph/schema/dataprocessing.gql index 71fcce0dd..1ffd931ac 100644 --- a/apiserver/graph/schema/dataprocessing.gql +++ b/apiserver/graph/schema/dataprocessing.gql @@ -72,6 +72,16 @@ query dataProcessDetails($input: DataProcessDetailsInput){ enable zh_name description + llm_config { + name + namespace + model + temperature + top_p + max_tokens + prompt_template + provider + } preview { file_name content { diff --git a/apiserver/graph/schema/dataprocessing.graphqls b/apiserver/graph/schema/dataprocessing.graphqls index 188ab3b94..131945d16 100644 --- a/apiserver/graph/schema/dataprocessing.graphqls +++ b/apiserver/graph/schema/dataprocessing.graphqls @@ -180,10 +180,22 @@ type DataProcessConfigChildren { enable: String zh_name: String description: String + llm_config: LLMConfig preview: [DataProcessConfigpreView] file_progress: [DataProcessConfigpreFileProgress] } +type LLMConfig { + name: String + namespace: String + model: String + temperature: String + top_p: String + max_tokens: String + prompt_template: String + provider: String +} + # 数据处理配置项预览 type DataProcessConfigpreView { file_name: String diff --git a/data-processing/Dockerfile b/data-processing/Dockerfile index f995bbb19..41e4ce3a0 100644 --- a/data-processing/Dockerfile +++ b/data-processing/Dockerfile @@ -17,6 +17,9 @@ RUN wget https://github.com/explosion/spacy-models/releases/download/zh_core_web && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i https://pypi.org/simple \ && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl +ENV DEFAULT_CONFIG=arcadia-config +ENV POD_NAMESPACE=arcadia + EXPOSE 28888 ADD . /arcadia_app/ diff --git a/data-processing/data_manipulation/data_store_process/minio_store_process.py b/data-processing/data_manipulation/data_store_process/minio_store_process.py index ac00e3942..1916c72dc 100644 --- a/data-processing/data_manipulation/data_store_process/minio_store_process.py +++ b/data-processing/data_manipulation/data_store_process/minio_store_process.py @@ -16,12 +16,13 @@ import io import logging import os +import ulid import pandas as pd from common import log_tag_const from common.config import config from data_store_clients import minio_store_client -from database_operate import data_process_db_operate +from database_operate import data_process_db_operate, data_process_document_db_operate from file_handle import csv_handle, pdf_handle, word_handle from kube import dataset_cr from utils import file_utils @@ -64,13 +65,31 @@ def text_manipulate( file_name=file_name['name'] ) + # 将文件信息存入data_process_task_document表中 + for file_name in file_names: + # 新增文档处理进度信息 + document_id = ulid.ulid() + document_insert_item = { + 'id': document_id, + 'task_id': id, + 'file_name': file_name['name'], + 'status': 'not_start', + 'progress': '0', + 'creator': req_json['creator'] + } + data_process_document_db_operate.add( + document_insert_item, + pool=pool + ) + file_name['document_id']=document_id + # 文件处理 task_status = 'process_complete' # 存放每个文件对应的数据量 data_volumes_file = [] for item in file_names: - result = [] + result = None file_name = item['name'] file_extension = file_name.split('.')[-1].lower() @@ -87,6 +106,7 @@ def text_manipulate( chunk_size=req_json.get('chunk_size'), chunk_overlap=req_json.get('chunk_overlap'), file_name=file_name, + document_id=document_id, support_type=support_type, conn_pool=pool, task_id=id, @@ -99,13 +119,14 @@ def text_manipulate( chunk_size=req_json.get('chunk_size'), chunk_overlap=req_json.get('chunk_overlap'), file_name=file_name, + document_id=document_id, support_type=support_type, conn_pool=pool, task_id=id, create_user=req_json['creator'] ) - if result.get('status') != 200: + if result is None or result.get('status') != 200: # 任务失败 task_status = 'process_fail' break diff --git a/data-processing/data_manipulation/database_operate/data_process_detail_db_operate.py b/data-processing/data_manipulation/database_operate/data_process_detail_db_operate.py index 5cb548d4b..649fe9f5f 100644 --- a/data-processing/data_manipulation/database_operate/data_process_detail_db_operate.py +++ b/data-processing/data_manipulation/database_operate/data_process_detail_db_operate.py @@ -13,7 +13,6 @@ # limitations under the License. -import ulid from database_clients import postgresql_pool_client from utils import date_time_utils diff --git a/data-processing/data_manipulation/database_operate/data_process_document_db_operate.py b/data-processing/data_manipulation/database_operate/data_process_document_db_operate.py new file mode 100644 index 000000000..194a820e6 --- /dev/null +++ b/data-processing/data_manipulation/database_operate/data_process_document_db_operate.py @@ -0,0 +1,185 @@ +# Copyright 2023 KubeAGI. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ulid +from database_clients import postgresql_pool_client +from utils import date_time_utils + +def add( + req_json, + pool +): + """Add a new record""" + now = date_time_utils.now_str() + user = req_json['creator'] + program = '数据处理文件进度-新增' + + params = { + 'id': req_json['id'], + 'file_name': req_json['file_name'], + 'status': req_json['status'], + 'progress': req_json['progress'], + 'task_id': req_json['task_id'], + 'create_datetime': now, + 'create_user': user, + 'create_program': program, + 'update_datetime': now, + 'update_user': user, + 'update_program': program + } + + sql = """ + insert into public.data_process_task_document ( + id, + file_name, + status, + progress, + task_id, + create_datetime, + create_user, + create_program, + update_datetime, + update_user, + update_program + ) + values ( + %(id)s, + %(file_name)s, + %(status)s, + %(progress)s, + %(task_id)s, + %(create_datetime)s, + %(create_user)s, + %(create_program)s, + %(update_datetime)s, + %(update_user)s, + %(update_program)s + ) + """.strip() + + res = postgresql_pool_client.execute_update(pool, sql, params) + return res + +def update_document_status_and_start_time( + req_json, + pool +): + """Update the status and start time with id""" + now = req_json['start_time'] + program = '文件开始处理-修改' + + params = { + 'id': req_json['id'], + 'status': req_json['status'], + 'start_time': now, + 'chunk_size': req_json['chunk_size'], + 'update_datetime': now, + 'update_program': program + } + + sql = """ + update public.data_process_task_document set + status = %(status)s, + start_time = %(start_time)s, + chunk_size = %(chunk_size)s, + update_datetime = %(update_datetime)s, + update_program = %(update_program)s + where + id = %(id)s + """.strip() + + res = postgresql_pool_client.execute_update(pool, sql, params) + return res + +def update_document_status_and_end_time( + req_json, + pool +): + """Update the status and start time with id""" + now = req_json['end_time'] + program = '文件处理完成-修改' + + params = { + 'id': req_json['id'], + 'status': req_json['status'], + 'end_time': now, + 'update_datetime': now, + 'update_program': program + } + + sql = """ + update public.data_process_task_document set + status = %(status)s, + end_time = %(end_time)s, + update_datetime = %(update_datetime)s, + update_program = %(update_program)s + where + id = %(id)s + """.strip() + + res = postgresql_pool_client.execute_update(pool, sql, params) + return res + +def update_document_progress( + req_json, + pool +): + """Update the progress with id""" + now = date_time_utils.now_str() + program = '文件处理进度-修改' + + params = { + 'id': req_json['id'], + 'progress': req_json['progress'], + 'update_datetime': now, + 'update_program': program + } + + sql = """ + update public.data_process_task_document set + progress = %(progress)s, + update_datetime = %(update_datetime)s, + update_program = %(update_program)s + where + id = %(id)s + """.strip() + + res = postgresql_pool_client.execute_update(pool, sql, params) + return res + +def list_file_by_task_id( + req_json, + pool +): + """info with id""" + params = { + 'task_id': req_json['task_id'] + } + + sql = """ + select + id, + file_name, + status, + start_time, + end_time, + progress + from + public.data_process_task_document + where + task_id = %(task_id)s + """.strip() + + res = postgresql_pool_client.execute_query(pool, sql, params) + return res diff --git a/data-processing/data_manipulation/file_handle/common_handle.py b/data-processing/data_manipulation/file_handle/common_handle.py index 14697f00b..cf89d6ce5 100644 --- a/data-processing/data_manipulation/file_handle/common_handle.py +++ b/data-processing/data_manipulation/file_handle/common_handle.py @@ -22,12 +22,12 @@ import ulid from common import log_tag_const from common.config import config -from database_operate import data_process_detail_db_operate +from database_operate import data_process_detail_db_operate, data_process_document_db_operate from langchain.text_splitter import SpacyTextSplitter from llm_api_service.qa_provider_open_ai import QAProviderOpenAI from llm_api_service.qa_provider_zhi_pu_ai_online import QAProviderZhiPuAIOnline from transform.text import clean_transform, privacy_transform -from utils import csv_utils, file_utils, docx_utils +from utils import csv_utils, file_utils, docx_utils, date_time_utils from kube import model_cr logger = logging.getLogger(__name__) @@ -35,6 +35,7 @@ def text_manipulate( file_name, + document_id, content, support_type, conn_pool, @@ -97,9 +98,9 @@ def text_manipulate( chunk_size=chunk_size, chunk_overlap=chunk_overlap, data=content, - name=llm_config.get('name'), - namespace=llm_config.get('namespace'), - model=llm_config.get('model') + document_id=document_id, + conn_pool=conn_pool, + llm_config=llm_config ) if qa_response.get('status') != 200: @@ -530,9 +531,9 @@ def _generate_qa_list( chunk_size, chunk_overlap, data, - name, - namespace, - model + document_id, + conn_pool, + llm_config ): """Generate the Question and Answer list. @@ -543,6 +544,14 @@ def _generate_qa_list( namespace: llms cr namespace; model: model id or model version; """ + name=llm_config.get('name') + namespace=llm_config.get('namespace') + model=llm_config.get('model') + temperature=llm_config.get('temperature') + prompt_template=llm_config.get('prompt_template') + top_p=llm_config.get('top_p') + max_tokens=llm_config.get('max_tokens') + # Split the text. if chunk_size is None: chunk_size = config.knowledge_chunk_size @@ -563,6 +572,13 @@ def _generate_qa_list( f"splitted text is: \n{texts}\n" ])) + # 更新文件状态为开始 + _update_document_status_and_start_time( + id=document_id, + texts=texts, + conn_pool=conn_pool + ) + # llms cr 中模型相关信息 llm_spec_info = model_cr.get_spec_for_llms_k8s_cr( name=name, @@ -578,7 +594,7 @@ def _generate_qa_list( namespace=config.k8s_pod_namespace ) logger.debug(''.join([ - f"Generate the QA list \n", + f"worker llm \n", f"name: {name}\n", f"namespace: {namespace}\n", f"model: {model}\n", @@ -589,16 +605,47 @@ def _generate_qa_list( qa_provider = QAProviderOpenAI( api_key='fake', base_url=base_url, - model=model + model=model, + temperature=temperature, + max_tokens=max_tokens ) + # text process success number + text_process_success_num=0 + for item in texts: text = item.replace("\n", "") - data = qa_provider.generate_qa_list(text) + data = qa_provider.generate_qa_list( + text=text, + prompt_template=prompt_template + ) if data.get('status') != 200: + # 文件处理失败,更新data_process_task_document中的文件状态 + _updata_document_status_and_end_time( + id=document_id, + status='fail', + conn_pool=conn_pool + ) + return data qa_list.extend(data.get('data')) + + # 更新文件处理进度 + text_process_success_num += 1 + progress = int(text_process_success_num / len(texts) * 100) + _updata_document_progress( + id=document_id, + progress=progress, + conn_pool=conn_pool + ) + + # 文件处理成功,更新data_process_task_document中的文件状态 + _updata_document_status_and_end_time( + id=document_id, + status='success', + conn_pool=conn_pool + ) else: endpoint = llm_spec_info.get('data').get('provider').get('endpoint') base_url = endpoint.get('url') @@ -612,19 +659,58 @@ def _generate_qa_list( api_key = secret_info.get('apiKey') llm_type = llm_spec_info.get('data').get('type') - if llm_type == 'zhipuai': + logger.debug(''.join([ + f"3rd_party llm \n", + f"name: {name}\n", + f"namespace: {namespace}\n", + f"model: {model}\n", + f"llm_type: {llm_type}\n" + ])) + + if llm_type == 'zhipuai': zhipuai_api_key = base64.b64decode(api_key).decode('utf-8') qa_provider = QAProviderZhiPuAIOnline(api_key=zhipuai_api_key) - + # text process success number + text_process_success_num=0 + # generate QA list for item in texts: text = item.replace("\n", "") - data = qa_provider.generate_qa_list(text) + data = qa_provider.generate_qa_list( + text=text, + model=model, + prompt_template=prompt_template, + top_p=top_p, + temperature=temperature + ) if data.get('status') != 200: + # 文件处理失败,更新data_process_task_document中的文件状态 + _updata_document_status_and_end_time( + id=document_id, + status='fail', + conn_pool=conn_pool + ) + return data qa_list.extend(data.get('data')) + + # 更新文件处理进度 + text_process_success_num += 1 + progress = int(text_process_success_num / len(texts) * 100) + _updata_document_progress( + id=document_id, + progress=progress, + conn_pool=conn_pool + ) + + # 文件处理成功,更新data_process_task_document中的文件状态 + _updata_document_status_and_end_time( + id=document_id, + status='success', + conn_pool=conn_pool + ) else: return { 'status': 1000, @@ -663,4 +749,104 @@ def _convert_support_type_to_map(supprt_type): for item in supprt_type: result[item['type']] = item - return result \ No newline at end of file + return result + +def _update_document_status_and_start_time( + id, + texts, + conn_pool +): + try: + now = date_time_utils.now_str() + document_update_item = { + 'id': id, + 'status': 'doing', + 'start_time': now, + 'chunk_size': len(texts) + } + data_process_document_db_operate.update_document_status_and_start_time( + document_update_item, + pool=conn_pool + ) + + return { + 'status': 200, + 'message': '', + 'data': '' + } + except Exception as ex: + logger.error(''.join([ + f"{log_tag_const.COMMON_HANDLE} update document status ", + f"\n{traceback.format_exc()}" + ])) + return { + 'status': 1000, + 'message': str(ex), + 'data': traceback.format_exc() + } + +def _updata_document_status_and_end_time( + id, + status, + conn_pool +): + try: + now = date_time_utils.now_str() + document_update_item = { + 'id': id, + 'status': status, + 'end_time': now + } + data_process_document_db_operate.update_document_status_and_end_time( + document_update_item, + pool=conn_pool + ) + + return { + 'status': 200, + 'message': '', + 'data': '' + } + except Exception as ex: + logger.error(''.join([ + f"{log_tag_const.COMMON_HANDLE} update document status ", + f"\n{traceback.format_exc()}" + ])) + return { + 'status': 1000, + 'message': str(ex), + 'data': traceback.format_exc() + } + +def _updata_document_progress( + id, + progress, + conn_pool +): + try: + now = date_time_utils.now_str() + document_update_item = { + 'id': id, + 'progress': progress + } + data_process_document_db_operate.update_document_progress( + document_update_item, + pool=conn_pool + ) + + return { + 'status': 200, + 'message': '', + 'data': '' + } + except Exception as ex: + logger.error(''.join([ + f"{log_tag_const.COMMON_HANDLE} update document progress ", + f"\n{traceback.format_exc()}" + ])) + return { + 'status': 1000, + 'message': str(ex), + 'data': traceback.format_exc() + } + diff --git a/data-processing/data_manipulation/file_handle/pdf_handle.py b/data-processing/data_manipulation/file_handle/pdf_handle.py index ba19959d9..8ea024098 100644 --- a/data-processing/data_manipulation/file_handle/pdf_handle.py +++ b/data-processing/data_manipulation/file_handle/pdf_handle.py @@ -25,6 +25,7 @@ def text_manipulate( file_name, + document_id, support_type, conn_pool, task_id, @@ -55,6 +56,7 @@ def text_manipulate( response = common_handle.text_manipulate( file_name=file_name, + document_id=document_id, content=content, support_type=support_type, conn_pool=conn_pool, diff --git a/data-processing/data_manipulation/file_handle/word_handle.py b/data-processing/data_manipulation/file_handle/word_handle.py index 2e5fbdd86..42e3f9436 100644 --- a/data-processing/data_manipulation/file_handle/word_handle.py +++ b/data-processing/data_manipulation/file_handle/word_handle.py @@ -25,6 +25,7 @@ def docx_text_manipulate( file_name, + document_id, support_type, conn_pool, task_id, @@ -55,6 +56,7 @@ def docx_text_manipulate( response = common_handle.text_manipulate( file_name=file_name, + document_id=document_id, content=content, support_type=support_type, conn_pool=conn_pool, diff --git a/data-processing/data_manipulation/kube/minio_cr.py b/data-processing/data_manipulation/kube/minio_cr.py index f6d8ef7be..0cf484d2c 100644 --- a/data-processing/data_manipulation/kube/minio_cr.py +++ b/data-processing/data_manipulation/kube/minio_cr.py @@ -53,8 +53,10 @@ def get_minio_config_in_k8s_configmap( minio_api_url = minio_cr_object['spec']['endpoint']['url'] minio_secure = True - insecure_str = str(minio_cr_object['spec']['endpoint']['insecure']) - if insecure_str == 'true': + insecure = minio_cr_object['spec']['endpoint'].get('insecure') + if insecure is None: + minio_secure = True + elif str(insecure).lower() == 'true': minio_secure = False diff --git a/data-processing/data_manipulation/kube/model_cr.py b/data-processing/data_manipulation/kube/model_cr.py index 192741008..ed55bab19 100644 --- a/data-processing/data_manipulation/kube/model_cr.py +++ b/data-processing/data_manipulation/kube/model_cr.py @@ -14,6 +14,7 @@ import logging import yaml +import traceback from utils import date_time_utils diff --git a/data-processing/data_manipulation/llm_api_service/qa_provider_open_ai.py b/data-processing/data_manipulation/llm_api_service/qa_provider_open_ai.py index 3cfa7ceef..0ebfec490 100644 --- a/data-processing/data_manipulation/llm_api_service/qa_provider_open_ai.py +++ b/data-processing/data_manipulation/llm_api_service/qa_provider_open_ai.py @@ -26,7 +26,7 @@ ChatPromptTemplate, HumanMessagePromptTemplate, ) -from llm_prompt_template import open_ai_prompt +from llm_prompt_template import llm_prompt from .base_qa_provider import BaseQAProvider @@ -39,14 +39,21 @@ def __init__( self, api_key, base_url, - model + model, + temperature=None, + max_tokens=None ): - # TODO: temperature and top_p/top_k should be configured later + if temperature is None: + temperature = 0.8 + if max_tokens is None: + max_tokens = 512 + self.llm = ChatOpenAI( openai_api_key=api_key, base_url=base_url, model=model, - temperature=0.8 + temperature=temperature, + max_tokens=max_tokens ) def generate_qa_list( @@ -64,7 +71,7 @@ def generate_qa_list( the prompt template """ if prompt_template is None: - prompt_template = open_ai_prompt.get_default_prompt_template() + prompt_template = llm_prompt.get_default_prompt_template() human_message_prompt = HumanMessagePromptTemplate.from_template(prompt_template) prompt = ChatPromptTemplate.from_messages([human_message_prompt]) diff --git a/data-processing/data_manipulation/llm_api_service/qa_provider_zhi_pu_ai_online.py b/data-processing/data_manipulation/llm_api_service/qa_provider_zhi_pu_ai_online.py index c6982f84a..f8c7091e1 100644 --- a/data-processing/data_manipulation/llm_api_service/qa_provider_zhi_pu_ai_online.py +++ b/data-processing/data_manipulation/llm_api_service/qa_provider_zhi_pu_ai_online.py @@ -21,7 +21,7 @@ import zhipuai from common import log_tag_const from common.config import config -from llm_prompt_template import zhi_pu_ai_prompt +from llm_prompt_template import llm_prompt from .base_qa_provider import BaseQAProvider @@ -32,15 +32,16 @@ class QAProviderZhiPuAIOnline(BaseQAProvider): """The QA provider is used by zhi pu ai online.""" def __init__(self, api_key=None): - if api_key is None: - api_key = config.zhipuai_api_key zhipuai.api_key = api_key def generate_qa_list( self, text, - prompt_template=None + model, + prompt_template=None, + top_p=None, + temperature=None ): """Generate the QA list. @@ -52,7 +53,11 @@ def generate_qa_list( the prompt template """ if prompt_template is None: - prompt_template = zhi_pu_ai_prompt.get_default_prompt_template() + prompt_template = llm_prompt.get_default_prompt_template() + if top_p is None: + top_p = 0.7 + if temperature is None: + temperature = 0.8 content = prompt_template.format( text=text @@ -79,12 +84,11 @@ def generate_qa_list( break else: - # TODO: temperature and top_p/top_k should be configured later response = zhipuai.model_api.invoke( model="chatglm_6b", prompt=[{"role": "user", "content": content}], - top_p=0.7, - temperature=0.9, + top_p=top_p, + temperature=temperature, ) if response['success']: result = self.__format_response_to_qa_list(response) diff --git a/data-processing/data_manipulation/llm_prompt_template/bai_chuan_2_prompt.py b/data-processing/data_manipulation/llm_prompt_template/bai_chuan_2_prompt.py deleted file mode 100644 index 2294fcd7b..000000000 --- a/data-processing/data_manipulation/llm_prompt_template/bai_chuan_2_prompt.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2023 KubeAGI. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: NOT used for now -def get_default_prompt_template(): - prompt_template = """ - {text} - - 将上述内容提出最多 25 个问题。给出每个问题的答案。每个问题必须有答案。 - """ - return prompt_template \ No newline at end of file diff --git a/data-processing/data_manipulation/llm_prompt_template/open_ai_prompt.py b/data-processing/data_manipulation/llm_prompt_template/llm_prompt.py similarity index 100% rename from data-processing/data_manipulation/llm_prompt_template/open_ai_prompt.py rename to data-processing/data_manipulation/llm_prompt_template/llm_prompt.py diff --git a/data-processing/data_manipulation/llm_prompt_template/zhi_pu_ai_prompt.py b/data-processing/data_manipulation/llm_prompt_template/zhi_pu_ai_prompt.py deleted file mode 100644 index d5a383672..000000000 --- a/data-processing/data_manipulation/llm_prompt_template/zhi_pu_ai_prompt.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2023 KubeAGI. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def get_default_prompt_template(): - prompt_template = """ - {text} - - 请将上述内容按照问答的方式,提出不超过 25 个问题,并给出每个问题的答案,每个问题必须有 Q 和对应的 A,并严格按照以下方式展示: Q1: 问题。\n A1: 答案。\n Q2: 问题 \n A2: 答案\n 注意,尽可能多的提出问题,但是 Q 不要重复,也不要出现只有 Q 没有 A 的情况。 - """ - - return prompt_template \ No newline at end of file diff --git a/data-processing/data_manipulation/server.py b/data-processing/data_manipulation/server.py index c69b2e781..fa423ad22 100644 --- a/data-processing/data_manipulation/server.py +++ b/data-processing/data_manipulation/server.py @@ -73,6 +73,24 @@ async def shutdown_web_server(app, loop): app.blueprint(data_process_controller.data_process) +@app.route('test_langchain', methods=['POST']) +async def test_langchain(request): + from langchain.chat_models import ChatOpenAI + + llm = ChatOpenAI( + openai_api_key='xx', + base_url='xx', + model='xx', + temperature=0.8, + top_k=0.1, + top_p=0.1 + ) + + return json({ + 'status': 200, + 'message': '', + 'data': '' + }) def _create_database_connection(): """Create a database connection.""" diff --git a/data-processing/data_manipulation/service/data_process_service.py b/data-processing/data_manipulation/service/data_process_service.py index 972302c94..002364eaf 100644 --- a/data-processing/data_manipulation/service/data_process_service.py +++ b/data-processing/data_manipulation/service/data_process_service.py @@ -21,10 +21,12 @@ from common import log_tag_const from data_store_process import minio_store_process from database_operate import (data_process_db_operate, - data_process_detail_db_operate) + data_process_detail_db_operate, + data_process_document_db_operate) from kube import dataset_cr from parallel import thread_parallel from utils import date_time_utils +from kube import model_cr logger = logging.getLogger(__name__) @@ -156,7 +158,12 @@ def info_by_id( process_cofig_map = _convert_config_info_to_map(data.get('data_process_config_info')) config_map_for_result = {} - _set_basic_info_for_config_map_for_result(config_map_for_result, process_cofig_map) + _set_basic_info_for_config_map_for_result( + config_map_for_result, + process_cofig_map, + task_id=id, + conn_pool=pool + ) _set_children_info_for_config_map_for_result( config_map_for_result, @@ -267,14 +274,16 @@ def _convert_config_info_to_map(config_info_list): """ result = {} for item in config_info_list: - result[item['type']] = 1 + result[item['type']] = item return result def _set_basic_info_for_config_map_for_result( from_result, - process_cofig_map + process_cofig_map, + task_id, + conn_pool ): """Set basic info for the config map for result. @@ -287,7 +296,10 @@ def _set_basic_info_for_config_map_for_result( from_result['chunk_processing'] = { 'name': 'chunk_processing', 'description': '拆分处理', - 'status': 'succeed', + 'status': _get_qa_split_status( + task_id=task_id, + conn_pool=conn_pool + ), 'children': [] } @@ -302,7 +314,7 @@ def _set_basic_info_for_config_map_for_result( from_result['clean'] = { 'name': 'clean', 'description': '异常清洗配置', - 'status': 'succeed', + 'status': 'success', 'children': [] } @@ -314,7 +326,7 @@ def _set_basic_info_for_config_map_for_result( from_result['privacy'] = { 'name': 'privacy', 'description': '数据隐私处理', - 'status': 'succeed', + 'status': 'success', 'children': [] } @@ -338,10 +350,17 @@ def _set_children_info_for_config_map_for_result( 'name': 'qa_split', 'enable': 'true', 'zh_name': 'QA拆分', - 'description': '根据文件中的文章与图表标题,自动将文件做 QA 拆分处理。', + 'description': '根据文件中的文档内容,自动将文件做 QA 拆分处理。', + 'llm_config': _get_llm_config( + qa_split_config = process_cofig_map.get('qa_split') + ), 'preview': _get_qa_list_preview( task_id=task_id, conn_pool=conn_pool + ), + 'file_progress': _get_file_progress( + task_id=task_id, + conn_pool=conn_pool ) }) @@ -568,5 +587,81 @@ def _get_qa_list_preview( return qa_list_preview +def _get_file_progress( + task_id, + conn_pool +): + """Get file progress. + + task_id: task id; + conn_pool: database connection pool + """ + # Get the detail info from the database. + detail_info_params = { + 'task_id': task_id + } + list_file = data_process_document_db_operate.list_file_by_task_id( + detail_info_params, + pool=conn_pool + ) + + return list_file.get('data') + +def _get_qa_split_status( + task_id, + conn_pool +): + """Get file progress. + + task_id: task id; + conn_pool: database connection pool + """ + # Get the detail info from the database. + status = 'doing' + detail_info_params = { + 'task_id': task_id + } + list_file = data_process_document_db_operate.list_file_by_task_id( + detail_info_params, + pool=conn_pool + ) + + if list_file.get('status') != 200 or len(list_file.get('data')) == 0: + return 'fail' + + file_dict = list_file.get('data') + + # 当所有文件状态都为success,则status为success + all_success = all(item['status'] == 'success' for item in file_dict) + if all_success: + return 'success' + + # 当所有文件状态都为not_start,则status为not_start + all_success = all(item['status'] == 'not_start' for item in file_dict) + if all_success: + return 'not_start' + + # 只要有一个文件状态为fail,则status为fail + status_fail = any(item['status'] == 'fail' for item in file_dict) + if status_fail: + return 'fail' + return status + +def _get_llm_config( + qa_split_config +): + llm_config = qa_split_config.get('llm_config') + + # llms cr 中模型相关信息 + llm_spec_info = model_cr.get_spec_for_llms_k8s_cr( + name=llm_config.get('name'), + namespace=llm_config.get('namespace') + ) + + if llm_spec_info.get('data').get('provider').get('worker'): + llm_config['provider'] = 'worker' + else: + llm_config['provider'] = '3rd_party' + return llm_config diff --git a/data-processing/data_manipulation/transform/text/support_type.py b/data-processing/data_manipulation/transform/text/support_type.py index 464a344aa..a20e74aaf 100644 --- a/data-processing/data_manipulation/transform/text/support_type.py +++ b/data-processing/data_manipulation/transform/text/support_type.py @@ -25,7 +25,7 @@ def get_default_support_types(): 'name': 'qa_split', 'enable': 'true', 'zh_name': 'QA拆分', - 'description': '根据文件中的文章与图表标题,自动将文件做 QA 拆分处理。' + 'description': '根据文件中的文档内容,自动将文件做 QA 拆分处理。' }, { 'name': 'document_chunk', diff --git a/data-processing/db-scripts/init-database-schema.sql b/data-processing/db-scripts/init-database-schema.sql index 651e7aff7..16e50620f 100644 --- a/data-processing/db-scripts/init-database-schema.sql +++ b/data-processing/db-scripts/init-database-schema.sql @@ -96,3 +96,39 @@ COMMENT ON COLUMN public.data_process_task_question_answer.update_user IS '更新用户'; COMMENT ON COLUMN public.data_process_task_question_answer.update_program IS '更新程序'; + CREATE TABLE IF NOT EXISTS public.data_process_task_document + ( + id character varying(64) COLLATE pg_catalog."default" NOT NULL, + file_name character varying(512) COLLATE pg_catalog."default", + status character varying(64) COLLATE pg_catalog."default", + process_info text COLLATE pg_catalog."default", + start_time character varying(32) COLLATE pg_catalog."default", + end_time character varying(32) COLLATE pg_catalog."default", + progress character varying(32) COLLATE pg_catalog."default", + chunk_size character varying(64) COLLATE pg_catalog."default", + task_id character varying(32) COLLATE pg_catalog."default", + create_datetime character varying(32) COLLATE pg_catalog."default", + create_user character varying(32) COLLATE pg_catalog."default", + create_program character varying(64) COLLATE pg_catalog."default", + update_datetime character varying(32) COLLATE pg_catalog."default", + update_user character varying(32) COLLATE pg_catalog."default", + update_program character varying(32) COLLATE pg_catalog."default", + CONSTRAINT data_process_task_document_pkey PRIMARY KEY (id) + ) + + COMMENT ON TABLE public.data_process_task_document IS '数据处理任务文档'; + COMMENT ON COLUMN public.data_process_task_document.id IS '主键'; + COMMENT ON COLUMN public.data_process_task_document.file_name IS '文件名称'; + COMMENT ON COLUMN public.data_process_task_document.status IS '状态 如not_start, doing, success, fail'; + COMMENT ON COLUMN public.data_process_task_document.process_info IS '处理信息'; + COMMENT ON COLUMN public.data_process_task_document.start_time IS '开始时间'; + COMMENT ON COLUMN public.data_process_task_document.end_time IS '结束时间'; + COMMENT ON COLUMN public.data_process_task_document.progress IS '进度'; + COMMENT ON COLUMN public.data_process_task_document.chunk_size IS '文本拆分数量'; + COMMENT ON COLUMN public.data_process_task_document.task_id IS '任务id'; + COMMENT ON COLUMN public.data_process_task_document.create_datetime IS '创建时间'; + COMMENT ON COLUMN public.data_process_task_document.create_user IS '创建用户'; + COMMENT ON COLUMN public.data_process_task_document.create_program IS '创建程序'; + COMMENT ON COLUMN public.data_process_task_document.update_datetime IS '更新时间'; + COMMENT ON COLUMN public.data_process_task_document.update_user IS '更新用户'; + COMMENT ON COLUMN public.data_process_task_document.update_program IS '更新程序'; diff --git a/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml b/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml index 24c977c7a..5b724bf02 100644 --- a/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml +++ b/deploy/charts/arcadia/templates/pg-init-data-configmap.yaml @@ -99,6 +99,43 @@ data: COMMENT ON COLUMN public.data_process_task_question_answer.update_user IS '更新用户'; COMMENT ON COLUMN public.data_process_task_question_answer.update_program IS '更新程序'; + CREATE TABLE IF NOT EXISTS public.data_process_task_document + ( + id character varying(64) COLLATE pg_catalog."default" NOT NULL, + file_name character varying(512) COLLATE pg_catalog."default", + status character varying(64) COLLATE pg_catalog."default", + process_info text COLLATE pg_catalog."default", + start_time character varying(32) COLLATE pg_catalog."default", + end_time character varying(32) COLLATE pg_catalog."default", + progress character varying(32) COLLATE pg_catalog."default", + chunk_size character varying(64) COLLATE pg_catalog."default", + task_id character varying(32) COLLATE pg_catalog."default", + create_datetime character varying(32) COLLATE pg_catalog."default", + create_user character varying(32) COLLATE pg_catalog."default", + create_program character varying(64) COLLATE pg_catalog."default", + update_datetime character varying(32) COLLATE pg_catalog."default", + update_user character varying(32) COLLATE pg_catalog."default", + update_program character varying(32) COLLATE pg_catalog."default", + CONSTRAINT data_process_task_document_pkey PRIMARY KEY (id) + ) + + COMMENT ON TABLE public.data_process_task_document IS '数据处理任务文档'; + COMMENT ON COLUMN public.data_process_task_document.id IS '主键'; + COMMENT ON COLUMN public.data_process_task_document.file_name IS '文件名称'; + COMMENT ON COLUMN public.data_process_task_document.status IS '状态 如not_start, doing, success, fail'; + COMMENT ON COLUMN public.data_process_task_document.process_info IS '处理信息'; + COMMENT ON COLUMN public.data_process_task_document.start_time IS '开始时间'; + COMMENT ON COLUMN public.data_process_task_document.end_time IS '结束时间'; + COMMENT ON COLUMN public.data_process_task_document.progress IS '进度'; + COMMENT ON COLUMN public.data_process_task_document.chunk_size IS '文本拆分数量'; + COMMENT ON COLUMN public.data_process_task_document.task_id IS '任务id'; + COMMENT ON COLUMN public.data_process_task_document.create_datetime IS '创建时间'; + COMMENT ON COLUMN public.data_process_task_document.create_user IS '创建用户'; + COMMENT ON COLUMN public.data_process_task_document.create_program IS '创建程序'; + COMMENT ON COLUMN public.data_process_task_document.update_datetime IS '更新时间'; + COMMENT ON COLUMN public.data_process_task_document.update_user IS '更新用户'; + COMMENT ON COLUMN public.data_process_task_document.update_program IS '更新程序'; + kind: ConfigMap metadata: name: pg-init-data