From 7f8cf97ef13b2608d7ef6203a944b79d47a33d05 Mon Sep 17 00:00:00 2001 From: bjwswang Date: Wed, 27 Mar 2024 08:23:12 +0000 Subject: [PATCH] feat: able to configure gpts categories Signed-off-by: bjwswang --- apiserver/graph/generated/generated.go | 384 +++++++++++++++++++- apiserver/graph/generated/models_gen.go | 11 +- apiserver/graph/impl/gpt.resolvers.go | 9 + apiserver/graph/schema/gpt.graphqls | 26 +- apiserver/pkg/gpt/gpt.go | 19 + apiserver/service/forward.go | 1 + deploy/charts/arcadia/Chart.yaml | 2 +- deploy/charts/arcadia/templates/config.yaml | 30 ++ deploy/charts/arcadia/values.yaml | 15 +- gqlgen.yaml | 2 + pkg/config/config_type.go | 34 +- pkg/config/gpts_config.go | 79 ++++ 12 files changed, 569 insertions(+), 43 deletions(-) create mode 100644 pkg/config/gpts_config.go diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index 199ad4669..d012ffe34 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -384,9 +384,16 @@ type ComplexityRoot struct { Prologue func(childComplexity int) int } + GPTCategory struct { + ID func(childComplexity int) int + Name func(childComplexity int) int + NameEn func(childComplexity int) int + } + GPTQuery struct { - GetGpt func(childComplexity int, name string) int - ListGpt func(childComplexity int, input ListGPTInput) int + GetGpt func(childComplexity int, name string) int + ListGPTCategory func(childComplexity int) int + ListGpt func(childComplexity int, input ListGPTInput) int } KnowledgeBase struct { @@ -857,6 +864,7 @@ type EmbedderQueryResolver interface { type GPTQueryResolver interface { GetGpt(ctx context.Context, obj *GPTQuery, name string) (*Gpt, error) ListGpt(ctx context.Context, obj *GPTQuery, input ListGPTInput) (*PaginatedResult, error) + ListGPTCategory(ctx context.Context, obj *GPTQuery) ([]*GPTCategory, error) } type KnowledgeBaseMutationResolver interface { CreateKnowledgeBase(ctx context.Context, obj *KnowledgeBaseMutation, input CreateKnowledgeBaseInput) (*KnowledgeBase, error) @@ -2611,6 +2619,27 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.GPT.Prologue(childComplexity), true + case "GPTCategory.id": + if e.complexity.GPTCategory.ID == nil { + break + } + + return e.complexity.GPTCategory.ID(childComplexity), true + + case "GPTCategory.name": + if e.complexity.GPTCategory.Name == nil { + break + } + + return e.complexity.GPTCategory.Name(childComplexity), true + + case "GPTCategory.nameEn": + if e.complexity.GPTCategory.NameEn == nil { + break + } + + return e.complexity.GPTCategory.NameEn(childComplexity), true + case "GPTQuery.getGPT": if e.complexity.GPTQuery.GetGpt == nil { break @@ -2623,6 +2652,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.GPTQuery.GetGpt(childComplexity, args["name"].(string)), true + case "GPTQuery.listGPTCategory": + if e.complexity.GPTQuery.ListGPTCategory == nil { + break + } + + return e.complexity.GPTQuery.ListGPTCategory(childComplexity), true + case "GPTQuery.listGPT": if e.complexity.GPTQuery.ListGpt == nil { break @@ -6274,16 +6310,7 @@ type Tool { union PageNode = Datasource | Model | Embedder | KnowledgeBase | Dataset | VersionedDataset | F | Worker | ApplicationMetadata | LLM | ModelService | RayCluster | RAG | GPT | Node `, BuiltIn: false}, - {Name: "../schema/gpt.graphqls", Input: `type GPTQuery { - getGPT(name: String!): GPT! - listGPT(input: ListGPTInput!): PaginatedResult! -} - -extend type Query{ - GPT: GPTQuery -} - -input ListGPTInput { + {Name: "../schema/gpt.graphqls", Input: `input ListGPTInput { """ category: gpt所属分类 @@ -6355,6 +6382,23 @@ type GPT { """ prologue: String } + +# GPTCategory in gpt store +type GPTCategory { + id: String! + name: String! + nameEn: String! +} + +type GPTQuery { + getGPT(name: String!): GPT! + listGPT(input: ListGPTInput!): PaginatedResult! + listGPTCategory: [GPTCategory]! +} + +extend type Query{ + GPT: GPTQuery +} `, BuiltIn: false}, {Name: "../schema/k8s.graphqls", Input: `type LabelSelectorRequirement { key: String @@ -19312,6 +19356,138 @@ func (ec *executionContext) fieldContext_GPT_prologue(ctx context.Context, field return fc, nil } +func (ec *executionContext) _GPTCategory_id(ctx context.Context, field graphql.CollectedField, obj *GPTCategory) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_GPTCategory_id(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ID, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_GPTCategory_id(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "GPTCategory", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _GPTCategory_name(ctx context.Context, field graphql.CollectedField, obj *GPTCategory) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_GPTCategory_name(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Name, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_GPTCategory_name(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "GPTCategory", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _GPTCategory_nameEn(ctx context.Context, field graphql.CollectedField, obj *GPTCategory) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_GPTCategory_nameEn(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.NameEn, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_GPTCategory_nameEn(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "GPTCategory", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _GPTQuery_getGPT(ctx context.Context, field graphql.CollectedField, obj *GPTQuery) (ret graphql.Marshaler) { fc, err := ec.fieldContext_GPTQuery_getGPT(ctx, field) if err != nil { @@ -19452,6 +19628,58 @@ func (ec *executionContext) fieldContext_GPTQuery_listGPT(ctx context.Context, f return fc, nil } +func (ec *executionContext) _GPTQuery_listGPTCategory(ctx context.Context, field graphql.CollectedField, obj *GPTQuery) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_GPTQuery_listGPTCategory(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.GPTQuery().ListGPTCategory(rctx, obj) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.([]*GPTCategory) + fc.Result = res + return ec.marshalNGPTCategory2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐGPTCategory(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_GPTQuery_listGPTCategory(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "GPTQuery", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_GPTCategory_id(ctx, field) + case "name": + return ec.fieldContext_GPTCategory_name(ctx, field) + case "nameEn": + return ec.fieldContext_GPTCategory_nameEn(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type GPTCategory", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _KnowledgeBase_id(ctx context.Context, field graphql.CollectedField, obj *KnowledgeBase) (ret graphql.Marshaler) { fc, err := ec.fieldContext_KnowledgeBase_id(ctx, field) if err != nil { @@ -26563,6 +26791,8 @@ func (ec *executionContext) fieldContext_Query_GPT(ctx context.Context, field gr return ec.fieldContext_GPTQuery_getGPT(ctx, field) case "listGPT": return ec.fieldContext_GPTQuery_listGPT(ctx, field) + case "listGPTCategory": + return ec.fieldContext_GPTQuery_listGPTCategory(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type GPTQuery", field.Name) }, @@ -41800,6 +42030,55 @@ func (ec *executionContext) _GPT(ctx context.Context, sel ast.SelectionSet, obj return out } +var gPTCategoryImplementors = []string{"GPTCategory"} + +func (ec *executionContext) _GPTCategory(ctx context.Context, sel ast.SelectionSet, obj *GPTCategory) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, gPTCategoryImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("GPTCategory") + case "id": + out.Values[i] = ec._GPTCategory_id(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "name": + out.Values[i] = ec._GPTCategory_name(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "nameEn": + out.Values[i] = ec._GPTCategory_nameEn(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var gPTQueryImplementors = []string{"GPTQuery"} func (ec *executionContext) _GPTQuery(ctx context.Context, sel ast.SelectionSet, obj *GPTQuery) graphql.Marshaler { @@ -41882,6 +42161,42 @@ func (ec *executionContext) _GPTQuery(ctx context.Context, sel ast.SelectionSet, continue } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "listGPTCategory": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._GPTQuery_listGPTCategory(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) @@ -46488,6 +46803,44 @@ func (ec *executionContext) marshalNGPT2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋap return ec._GPT(ctx, sel, v) } +func (ec *executionContext) marshalNGPTCategory2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐGPTCategory(ctx context.Context, sel ast.SelectionSet, v []*GPTCategory) graphql.Marshaler { + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalOGPTCategory2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐGPTCategory(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + return ret +} + func (ec *executionContext) unmarshalNInt2int(ctx context.Context, v interface{}) (int, error) { res, err := graphql.UnmarshalInt(v) return res, graphql.ErrorOnPath(ctx, err) @@ -48048,6 +48401,13 @@ func (ec *executionContext) marshalOFloat2ᚖfloat64(ctx context.Context, sel as return graphql.WrapContextMarshaler(ctx, res) } +func (ec *executionContext) marshalOGPTCategory2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐGPTCategory(ctx context.Context, sel ast.SelectionSet, v *GPTCategory) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._GPTCategory(ctx, sel, v) +} + func (ec *executionContext) marshalOGPTQuery2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐGPTQuery(ctx context.Context, sel ast.SelectionSet, v *GPTQuery) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 875324daa..dfac978d2 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -857,9 +857,16 @@ type Gpt struct { func (Gpt) IsPageNode() {} +type GPTCategory struct { + ID string `json:"id"` + Name string `json:"name"` + NameEn string `json:"nameEn"` +} + type GPTQuery struct { - GetGpt Gpt `json:"getGPT"` - ListGpt PaginatedResult `json:"listGPT"` + GetGpt Gpt `json:"getGPT"` + ListGpt PaginatedResult `json:"listGPT"` + ListGPTCategory []*GPTCategory `json:"listGPTCategory"` } // 知识库 diff --git a/apiserver/graph/impl/gpt.resolvers.go b/apiserver/graph/impl/gpt.resolvers.go index a1fd6822f..9b256e8b4 100644 --- a/apiserver/graph/impl/gpt.resolvers.go +++ b/apiserver/graph/impl/gpt.resolvers.go @@ -29,6 +29,15 @@ func (r *gPTQueryResolver) ListGpt(ctx context.Context, obj *generated.GPTQuery, return gpt.ListGPT(ctx, c, input) } +// ListGPTCategory is the resolver for the listGPTCategory field. +func (r *gPTQueryResolver) ListGPTCategory(ctx context.Context, obj *generated.GPTQuery) ([]*generated.GPTCategory, error) { + c, err := getAdminClient() + if err != nil { + return nil, err + } + return gpt.ListGPTCategory(ctx, c) +} + // Gpt is the resolver for the GPT field. func (r *queryResolver) Gpt(ctx context.Context) (*generated.GPTQuery, error) { return &generated.GPTQuery{}, nil diff --git a/apiserver/graph/schema/gpt.graphqls b/apiserver/graph/schema/gpt.graphqls index 1b18285ff..788ff7963 100644 --- a/apiserver/graph/schema/gpt.graphqls +++ b/apiserver/graph/schema/gpt.graphqls @@ -1,12 +1,3 @@ -type GPTQuery { - getGPT(name: String!): GPT! - listGPT(input: ListGPTInput!): PaginatedResult! -} - -extend type Query{ - GPT: GPTQuery -} - input ListGPTInput { """ @@ -79,3 +70,20 @@ type GPT { """ prologue: String } + +# GPTCategory in gpt store +type GPTCategory { + id: String! + name: String! + nameEn: String! +} + +type GPTQuery { + getGPT(name: String!): GPT! + listGPT(input: ListGPTInput!): PaginatedResult! + listGPTCategory: [GPTCategory]! +} + +extend type Query{ + GPT: GPTQuery +} diff --git a/apiserver/pkg/gpt/gpt.go b/apiserver/pkg/gpt/gpt.go index f98192f38..8b0f40e4d 100644 --- a/apiserver/pkg/gpt/gpt.go +++ b/apiserver/pkg/gpt/gpt.go @@ -33,6 +33,7 @@ import ( "github.com/kubeagi/arcadia/apiserver/pkg/chat" "github.com/kubeagi/arcadia/apiserver/pkg/chat/storage" "github.com/kubeagi/arcadia/apiserver/pkg/common" + "github.com/kubeagi/arcadia/pkg/config" ) var ( @@ -92,6 +93,7 @@ func GetGPT(ctx context.Context, c client.Client, name string) (*generated.Gpt, return app2gpt(app, c) } +// ListGPT list all gpt func ListGPT(ctx context.Context, c client.Client, input generated.ListGPTInput) (*generated.PaginatedResult, error) { keyword := pointer.StringDeref(input.Keyword, "") category := pointer.StringDeref(input.Category, "") @@ -121,3 +123,20 @@ func ListGPT(ctx context.Context, c client.Client, input generated.ListGPTInput) return app2gpt(app, c) }, filter...) } + +// ListGPTCategory list all categories +func ListGPTCategory(ctx context.Context, c client.Client) ([]*generated.GPTCategory, error) { + categories, err := config.GetGPTsCategories(ctx, c) + if err != nil { + return nil, err + } + resp := make([]*generated.GPTCategory, len(categories)) + for i := range categories { + resp[i] = &generated.GPTCategory{ + Name: categories[i].Name, + NameEn: categories[i].NameEn, + ID: categories[i].ID, + } + } + return resp, nil +} diff --git a/apiserver/service/forward.go b/apiserver/service/forward.go index 4ec598b3b..9bc0ca870 100644 --- a/apiserver/service/forward.go +++ b/apiserver/service/forward.go @@ -45,6 +45,7 @@ const ( ) type ( + // FrowarAPI is the forward api handler which forward requests to other services FrowarAPI struct{} SummaryResp struct { Summary string `json:"summary"` diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 4a054eed5..b401ef1ba 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(Also a KubeBB Component) for KubeAGI Arcadia type: application -version: 0.3.22 +version: 0.3.23 appVersion: "0.2.1" keywords: diff --git a/deploy/charts/arcadia/templates/config.yaml b/deploy/charts/arcadia/templates/config.yaml index d8f83c9e6..d4a0c5cc9 100644 --- a/deploy/charts/arcadia/templates/config.yaml +++ b/deploy/charts/arcadia/templates/config.yaml @@ -69,6 +69,36 @@ data: user: {{ .Values.postgresql.global.postgresql.auth.username }} password: {{ .Values.postgresql.global.postgresql.auth.password }} database: {{ .Values.postgresql.global.postgresql.auth.database }} + # configurations for gpts + gptsConfig: | + categories: + - id: 1 + name: "通用对话" + nameEn: "General Conversation" + - id: 2 + name: "工作" + nameEn: "Work" + - id: 3 + name: "学习" + nameEn: "Learning" + - id: 4 + name: "效率" + nameEn: "Efficiency" + - id: 4 + name: "人物扮演" + nameEn: "Character Play" + - id: 5 + name: "游戏" + nameEn: "Game" + - id: 6 + name: "生活" + nameEn: "Life" + - id: 7 + name: "情感" + nameEn: "Emotion" + - id: 8 + name: "动漫" + nameEn: "Anime" {{- end }} kind: ConfigMap metadata: diff --git a/deploy/charts/arcadia/values.yaml b/deploy/charts/arcadia/values.yaml index 71f2eb25b..f5b1a600e 100644 --- a/deploy/charts/arcadia/values.yaml +++ b/deploy/charts/arcadia/values.yaml @@ -195,12 +195,9 @@ chromadb: ray: # clusters provided by ray # For more information on cluster configurations,please refer to http://kubeagi.k8s.com.cn/docs/Configuration/DistributedInference/run-inference-using-ray - clusters: - # cluster1 comes from https://github.com/kubeagi/arcadia/blob/main/config/samples/ray.io_v1_raycluster.yaml - - name: cluster1 - headAddress: raycluster-kuberay-head-svc.kuberay-system.svc:6379 - pythonVersion: 3.9.18 - dashboardHost: raycluster-kuberay-head-svc.kuberay-system.svc:8265 - -rerank: - enabled: true + clusters: {} + # # cluster1 comes from https://github.com/kubeagi/arcadia/blob/main/config/samples/ray.io_v1_raycluster.yaml + # - name: cluster1 + # headAddress: raycluster-kuberay-head-svc.kuberay-system.svc:6379 + # pythonVersion: 3.9.18 + # dashboardHost: raycluster-kuberay-head-svc.kuberay-system.svc:8265 \ No newline at end of file diff --git a/gqlgen.yaml b/gqlgen.yaml index 126e4825a..bc33b9f88 100644 --- a/gqlgen.yaml +++ b/gqlgen.yaml @@ -298,6 +298,8 @@ models: resolver: true listGPT: resolver: true + listGPTCategory: + resolver: true NodeQuery: fields: listNodes: diff --git a/pkg/config/config_type.go b/pkg/config/config_type.go index 1c16651b4..e54103e73 100644 --- a/pkg/config/config_type.go +++ b/pkg/config/config_type.go @@ -33,30 +33,41 @@ type Config struct { // Gateway to access LLM api services Gateway *Gateway `json:"gateway,omitempty"` - // Embedder specifies the default embedder for Arcadia to generate embeddings - Embedder *arcadiav1alpha1.TypedObjectReference `json:"embedder,omitempty"` - - // VectorStore to access VectorStore api services - VectorStore *arcadiav1alpha1.TypedObjectReference `json:"vectorStore,omitempty"` - - // Streamlit to get the Streamlit configuration - Streamlit *Streamlit `json:"streamlit,omitempty"` + // EmbeddingSuite here represents the system embedding service provided by the system + EmbeddingSuite // Resource pool managed by Ray cluster RayClusters []RayCluster `json:"rayClusters,omitempty"` // the default rerank model Rerank *arcadiav1alpha1.TypedObjectReference `json:"rerank,omitempty"` + + // Streamlit to get the Streamlit configuration + // Deprecated: this field no longer maintained + Streamlit *Streamlit `json:"streamlit,omitempty"` +} + +// EmbeddingSuite contains everything required to provide embedding service +type EmbeddingSuite struct { + // Embedder specifies the default embedder for Arcadia to generate embeddings + Embedder *arcadiav1alpha1.TypedObjectReference `json:"embedder,omitempty"` + + // VectorStore to access VectorStore api services + VectorStore *arcadiav1alpha1.TypedObjectReference `json:"vectorStore,omitempty"` } // Gateway defines the way to access llm apis host by Arcadia type Gateway struct { + // ExternalAPIServer is the api(LLM/Embedding) server address that can be accessed from internet ExternalAPIServer string `json:"externalApiServer,omitempty"` - APIServer string `json:"apiServer,omitempty"` - Controller string `json:"controller,omitempty"` + // APIServer is api(LLM/Embedding) server which can be accessed within platform + APIServer string `json:"apiServer,omitempty"` + // Controller is the server address which is responsible for llm/embedding service registration + Controller string `json:"controller,omitempty"` } // Streamlit defines the configuration of streamlit app +// Deprecated: no longer maintained type Streamlit struct { Image string `json:"image"` IngressClassName string `json:"ingressClassName"` @@ -78,6 +89,7 @@ type RayCluster struct { RayVersion string `json:"rayVersion,omitempty"` } +// String format raycluster into string func (rayCluster RayCluster) String() string { return fmt.Sprintf("Name:%s HeadAddress: %s DashboardHost:%s PythonVersion:%s RayVersion: %s", rayCluster.Name, rayCluster.HeadAddress, rayCluster.DashboardHost, rayCluster.PythonVersion, rayCluster.RayVersion) } @@ -91,6 +103,7 @@ func (rayCluster RayCluster) GetRayVersion() string { return rayCluster.RayVersion } +// GetPythonVersion in ray cluster func (rayCluster RayCluster) GetPythonVersion() string { // Default python version is 3.9.5 if rayCluster.PythonVersion == "" { @@ -99,6 +112,7 @@ func (rayCluster RayCluster) GetPythonVersion() string { return rayCluster.PythonVersion } +// DefaultRayCluster which can be used for vllm worker as local ray cluster func DefaultRayCluster() RayCluster { return RayCluster{ Name: "default", diff --git a/pkg/config/gpts_config.go b/pkg/config/gpts_config.go new file mode 100644 index 000000000..eca985e6a --- /dev/null +++ b/pkg/config/gpts_config.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/yaml" + "k8s.io/utils/env" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/pkg/utils" +) + +var ( + ErrNoGPTsConfig = fmt.Errorf("gpts config in configmap is empty") + ErrNoGPTsConfigCategory = fmt.Errorf("gpts config Categories in comfigmap is not found") +) + +// GPTsConfig is the configurations for GPT Store +type GPTsConfig struct { + Categories []Category `json:"categories,omitempty"` +} + +// Category in gpt store +type Category struct { + ID string `json:"id"` + Name string `json:"name"` + NameEn string `json:"nameEn,omitempty"` +} + +// GetGPTsConfig gets the gpts configurations +func GetGPTsConfig(ctx context.Context, c client.Client) (gptsConfig *GPTsConfig, err error) { + cmName := env.GetString(EnvConfigKey, EnvConfigDefaultValue) + if cmName == "" { + return nil, ErrNoConfigEnv + } + cmNamespace := utils.GetCurrentNamespace() + cm := &corev1.ConfigMap{} + if err = c.Get(ctx, client.ObjectKey{Name: cmName, Namespace: cmNamespace}, cm); err != nil { + return nil, err + } + value, ok := cm.Data["gptsConfig"] + if !ok || len(value) == 0 { + return nil, ErrNoConfig + } + if err = yaml.Unmarshal([]byte(value), &gptsConfig); err != nil { + return nil, err + } + return gptsConfig, nil +} + +// GetGPTsCategories gets the gpts Categories +func GetGPTsCategories(ctx context.Context, c client.Client) (categories []Category, err error) { + gptsConfig, err := GetGPTsConfig(ctx, c) + if err != nil { + return nil, err + } + if gptsConfig.Categories == nil || len(gptsConfig.Categories) == 0 { + return nil, ErrNoGPTsConfigCategory + } + return gptsConfig.Categories, nil +}