From 3aeb6255d21b9d31c56bc89ddcd9d23bf60095db Mon Sep 17 00:00:00 2001 From: 0xff-dev Date: Thu, 21 Dec 2023 22:49:36 +0800 Subject: [PATCH] feat: list llms and embedders --- apiserver/graph/generated/generated.go | 295 +++++++++----- apiserver/graph/generated/models_gen.go | 9 +- .../graph/impl/modelservice.resolvers.go | 27 +- apiserver/graph/schema/modelservice.gql | 115 ++++++ apiserver/graph/schema/modelservice.graphqls | 16 +- apiserver/pkg/embedder/embedder.go | 1 + apiserver/pkg/modelservice/modelservice.go | 371 ++++++++++++++++++ apiserver/pkg/modelservice/sort.go | 72 ++++ apiserver/pkg/modelservice/sort_test.go | 76 ++++ gqlgen.yaml | 7 +- 10 files changed, 886 insertions(+), 103 deletions(-) create mode 100644 apiserver/pkg/modelservice/sort.go create mode 100644 apiserver/pkg/modelservice/sort_test.go diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index bcd1c19ad..fee3a6443 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -56,6 +56,7 @@ type ResolverRoot interface { ModelMutation() ModelMutationResolver ModelQuery() ModelQueryResolver ModelServiceMutation() ModelServiceMutationResolver + ModelServiceQuery() ModelServiceQueryResolver Mutation() MutationResolver Query() QueryResolver VersionedDataset() VersionedDatasetResolver @@ -432,6 +433,7 @@ type ComplexityRoot struct { LlmResource func(childComplexity int) int Name func(childComplexity int) int Namespace func(childComplexity int) int + Resource func(childComplexity int) int Types func(childComplexity int) int UpdateTimestamp func(childComplexity int) int } @@ -443,7 +445,7 @@ type ComplexityRoot struct { } ModelServiceQuery struct { - GetModelService func(childComplexity int, name string, apiType string) int + GetModelService func(childComplexity int, name string, namespace string, apiType string) int ListModelServices func(childComplexity int, input *ListModelService) int } @@ -671,6 +673,10 @@ type ModelServiceMutationResolver interface { UpdateModelService(ctx context.Context, obj *ModelServiceMutation, input *UpdateModelServiceInput) (*ModelService, error) DeleteModelService(ctx context.Context, obj *ModelServiceMutation, input *DeleteCommonInput) (*string, error) } +type ModelServiceQueryResolver interface { + GetModelService(ctx context.Context, obj *ModelServiceQuery, name string, namespace string, apiType string) (*ModelService, error) + ListModelServices(ctx context.Context, obj *ModelServiceQuery, input *ListModelService) (*PaginatedResult, error) +} type MutationResolver interface { Hello(ctx context.Context, name string) (string, error) Application(ctx context.Context) (*ApplicationMutation, error) @@ -2659,6 +2665,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ModelService.Namespace(childComplexity), true + case "ModelService.resource": + if e.complexity.ModelService.Resource == nil { + break + } + + return e.complexity.ModelService.Resource(childComplexity), true + case "ModelService.types": if e.complexity.ModelService.Types == nil { break @@ -2719,7 +2732,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.ModelServiceQuery.GetModelService(childComplexity, args["name"].(string), args["apiType"].(string)), true + return e.complexity.ModelServiceQuery.GetModelService(childComplexity, args["name"].(string), args["namespace"].(string), args["apiType"].(string)), true case "ModelServiceQuery.listModelServices": if e.complexity.ModelServiceQuery.ListModelServices == nil { @@ -5209,6 +5222,11 @@ extend type Query { """ llmResource: LLM embedderResource: Embedder + + """ + 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 + """ + resource: Resources } input CreateModelServiceInput { """模型服务资源名称(不可同名)""" @@ -5283,6 +5301,7 @@ input ListModelService { keyword: String namespace: String! page: Int + pageSize: Int """ all, llm, embedding @@ -5292,12 +5311,12 @@ input ListModelService { """ worker, 3rd """ - providerType: String! + providerType: String """ openai, zhipuai """ - apiType: String! + apiType: String } type ModelServiceMutation { @@ -5311,13 +5330,14 @@ extend type Mutation { } type ModelServiceQuery { - getModelService(name: String!, apiType: String!): ModelService - listModelServices(input: ListModelService): [ModelService] + getModelService(name: String!, namespace: String!, apiType: String!): ModelService + listModelServices(input: ListModelService): PaginatedResult! } extend type Query { ModelService: ModelServiceQuery -}`, BuiltIn: false}, +} +`, BuiltIn: false}, {Name: "../schema/versioned_dataset.graphqls", Input: `scalar Int64 """ VersionedDataset @@ -6497,14 +6517,23 @@ func (ec *executionContext) field_ModelServiceQuery_getModelService_args(ctx con } args["name"] = arg0 var arg1 string + if tmp, ok := rawArgs["namespace"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("namespace")) + arg1, err = ec.unmarshalNString2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["namespace"] = arg1 + var arg2 string if tmp, ok := rawArgs["apiType"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("apiType")) - arg1, err = ec.unmarshalNString2string(ctx, tmp) + arg2, err = ec.unmarshalNString2string(ctx, tmp) if err != nil { return nil, err } } - args["apiType"] = arg1 + args["apiType"] = arg2 return args, nil } @@ -18720,6 +18749,55 @@ func (ec *executionContext) fieldContext_ModelService_embedderResource(ctx conte return fc, nil } +func (ec *executionContext) _ModelService_resource(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_resource(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Resource, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*Resources) + fc.Result = res + return ec.marshalOResources2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResources(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ModelService_resource(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ModelService", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "cpu": + return ec.fieldContext_Resources_cpu(ctx, field) + case "memory": + return ec.fieldContext_Resources_memory(ctx, field) + case "nvidiaGPU": + return ec.fieldContext_Resources_nvidiaGPU(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Resources", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _ModelServiceMutation_createModelService(ctx context.Context, field graphql.CollectedField, obj *ModelServiceMutation) (ret graphql.Marshaler) { fc, err := ec.fieldContext_ModelServiceMutation_createModelService(ctx, field) if err != nil { @@ -18787,6 +18865,8 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_createModelService return ec.fieldContext_ModelService_llmResource(ctx, field) case "embedderResource": return ec.fieldContext_ModelService_embedderResource(ctx, field) + case "resource": + return ec.fieldContext_ModelService_resource(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -18872,6 +18952,8 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_updateModelService return ec.fieldContext_ModelService_llmResource(ctx, field) case "embedderResource": return ec.fieldContext_ModelService_embedderResource(ctx, field) + case "resource": + return ec.fieldContext_ModelService_resource(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -18956,7 +19038,7 @@ func (ec *executionContext) _ModelServiceQuery_getModelService(ctx context.Conte }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.GetModelService, nil + return ec.resolvers.ModelServiceQuery().GetModelService(rctx, obj, fc.Args["name"].(string), fc.Args["namespace"].(string), fc.Args["apiType"].(string)) }) if err != nil { ec.Error(ctx, err) @@ -18974,8 +19056,8 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_getModelService(ctx c fc = &graphql.FieldContext{ Object: "ModelServiceQuery", Field: field, - IsMethod: false, - IsResolver: false, + IsMethod: true, + IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { case "id": @@ -19006,6 +19088,8 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_getModelService(ctx c return ec.fieldContext_ModelService_llmResource(ctx, field) case "embedderResource": return ec.fieldContext_ModelService_embedderResource(ctx, field) + case "resource": + return ec.fieldContext_ModelService_resource(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -19038,58 +19122,43 @@ func (ec *executionContext) _ModelServiceQuery_listModelServices(ctx context.Con }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.ListModelServices, nil + return ec.resolvers.ModelServiceQuery().ListModelServices(rctx, obj, fc.Args["input"].(*ListModelService)) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.([]*ModelService) + res := resTmp.(*PaginatedResult) fc.Result = res - return ec.marshalOModelService2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx, field.Selections, res) + return ec.marshalNPaginatedResult2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐPaginatedResult(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_ModelServiceQuery_listModelServices(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelServiceQuery", Field: field, - IsMethod: false, - IsResolver: false, + IsMethod: true, + IsResolver: true, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { switch field.Name { - case "id": - return ec.fieldContext_ModelService_id(ctx, field) - case "name": - return ec.fieldContext_ModelService_name(ctx, field) - case "namespace": - return ec.fieldContext_ModelService_namespace(ctx, field) - case "labels": - return ec.fieldContext_ModelService_labels(ctx, field) - case "annotations": - return ec.fieldContext_ModelService_annotations(ctx, field) - case "creator": - return ec.fieldContext_ModelService_creator(ctx, field) - case "displayName": - return ec.fieldContext_ModelService_displayName(ctx, field) - case "description": - return ec.fieldContext_ModelService_description(ctx, field) - case "types": - return ec.fieldContext_ModelService_types(ctx, field) - case "creationTimestamp": - return ec.fieldContext_ModelService_creationTimestamp(ctx, field) - case "updateTimestamp": - return ec.fieldContext_ModelService_updateTimestamp(ctx, field) - case "apiType": - return ec.fieldContext_ModelService_apiType(ctx, field) - case "llmResource": - return ec.fieldContext_ModelService_llmResource(ctx, field) - case "embedderResource": - return ec.fieldContext_ModelService_embedderResource(ctx, field) + case "hasNextPage": + return ec.fieldContext_PaginatedResult_hasNextPage(ctx, field) + case "nodes": + return ec.fieldContext_PaginatedResult_nodes(ctx, field) + case "page": + return ec.fieldContext_PaginatedResult_page(ctx, field) + case "pageSize": + return ec.fieldContext_PaginatedResult_pageSize(ctx, field) + case "totalCount": + return ec.fieldContext_PaginatedResult_totalCount(ctx, field) } - return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) + return nil, fmt.Errorf("no field named %q was found under type PaginatedResult", field.Name) }, } defer func() { @@ -27586,7 +27655,7 @@ func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"keyword", "namespace", "page", "modelType", "providerType", "apiType"} + fieldsInOrder := [...]string{"keyword", "namespace", "page", "pageSize", "modelType", "providerType", "apiType"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -27620,6 +27689,15 @@ func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, return it, err } it.Page = data + case "pageSize": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("pageSize")) + data, err := ec.unmarshalOInt2ᚖint(ctx, v) + if err != nil { + return it, err + } + it.PageSize = data case "modelType": var err error @@ -27633,7 +27711,7 @@ func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("providerType")) - data, err := ec.unmarshalNString2string(ctx, v) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) if err != nil { return it, err } @@ -27642,7 +27720,7 @@ func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("apiType")) - data, err := ec.unmarshalNString2string(ctx, v) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) if err != nil { return it, err } @@ -32448,6 +32526,8 @@ func (ec *executionContext) _ModelService(ctx context.Context, sel ast.Selection out.Values[i] = ec._ModelService_llmResource(ctx, field, obj) case "embedderResource": out.Values[i] = ec._ModelService_embedderResource(ctx, field, obj) + case "resource": + out.Values[i] = ec._ModelService_resource(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -32622,9 +32702,74 @@ func (ec *executionContext) _ModelServiceQuery(ctx context.Context, sel ast.Sele case "__typename": out.Values[i] = graphql.MarshalString("ModelServiceQuery") case "getModelService": - out.Values[i] = ec._ModelServiceQuery_getModelService(ctx, field, obj) + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ModelServiceQuery_getModelService(ctx, field, obj) + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) case "listModelServices": - out.Values[i] = ec._ModelServiceQuery_listModelServices(ctx, field, obj) + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ModelServiceQuery_listModelServices(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -35878,47 +36023,6 @@ func (ec *executionContext) marshalOModelQuery2ᚖgithubᚗcomᚋkubeagiᚋarcad return ec._ModelQuery(ctx, sel, v) } -func (ec *executionContext) marshalOModelService2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx context.Context, sel ast.SelectionSet, v []*ModelService) graphql.Marshaler { - if v == nil { - return graphql.Null - } - ret := make(graphql.Array, len(v)) - var wg sync.WaitGroup - isLen1 := len(v) == 1 - if !isLen1 { - wg.Add(len(v)) - } - for i := range v { - i := i - fc := &graphql.FieldContext{ - Index: &i, - Result: &v[i], - } - ctx := graphql.WithFieldContext(ctx, fc) - f := func(i int) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = nil - } - }() - if !isLen1 { - defer wg.Done() - } - ret[i] = ec.marshalOModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx, sel, v[i]) - } - if isLen1 { - f(i) - } else { - go f(i) - } - - } - wg.Wait() - - return ret -} - func (ec *executionContext) marshalOModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx context.Context, sel ast.SelectionSet, v *ModelService) graphql.Marshaler { if v == nil { return graphql.Null @@ -36009,6 +36113,13 @@ func (ec *executionContext) marshalOPaginatedDataProcessItem2ᚖgithubᚗcomᚋk return ec._PaginatedDataProcessItem(ctx, sel, v) } +func (ec *executionContext) marshalOResources2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResources(ctx context.Context, sel ast.SelectionSet, v *Resources) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._Resources(ctx, sel, v) +} + func (ec *executionContext) unmarshalOResourcesInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResourcesInput(ctx context.Context, v interface{}) (*ResourcesInput, error) { if v == nil { return nil, nil diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 618491c50..ae39189fd 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -892,12 +892,13 @@ type ListModelService struct { Keyword *string `json:"keyword,omitempty"` Namespace string `json:"namespace"` Page *int `json:"page,omitempty"` + PageSize *int `json:"pageSize,omitempty"` // all, llm, embedding ModelType string `json:"modelType"` // worker, 3rd - ProviderType string `json:"providerType"` + ProviderType *string `json:"providerType,omitempty"` // openai, zhipuai - APIType string `json:"apiType"` + APIType *string `json:"apiType,omitempty"` } type ListVersionedDatasetInput struct { @@ -1007,6 +1008,8 @@ type ModelService struct { // 模型对应的 LLM 及 embedder CR 资源 LlmResource *Llm `json:"llmResource,omitempty"` EmbedderResource *Embedder `json:"embedderResource,omitempty"` + // 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 + Resource *Resources `json:"resource,omitempty"` } func (ModelService) IsPageNode() {} @@ -1019,7 +1022,7 @@ type ModelServiceMutation struct { type ModelServiceQuery struct { GetModelService *ModelService `json:"getModelService,omitempty"` - ListModelServices []*ModelService `json:"listModelServices,omitempty"` + ListModelServices PaginatedResult `json:"listModelServices"` } // 对象存储的使用信息 diff --git a/apiserver/graph/impl/modelservice.resolvers.go b/apiserver/graph/impl/modelservice.resolvers.go index 28198b9fa..338c4d688 100644 --- a/apiserver/graph/impl/modelservice.resolvers.go +++ b/apiserver/graph/impl/modelservice.resolvers.go @@ -6,7 +6,6 @@ package impl import ( "context" - "fmt" "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/modelservice" @@ -39,6 +38,24 @@ func (r *modelServiceMutationResolver) DeleteModelService(ctx context.Context, o return modelservice.DeleteModelService(ctx, c, input) } +// GetModelService is the resolver for the getModelService field. +func (r *modelServiceQueryResolver) GetModelService(ctx context.Context, obj *generated.ModelServiceQuery, name string, namespace string, apiType string) (*generated.ModelService, error) { + c, err := getClientFromCtx(ctx) + if err != nil { + return nil, err + } + return modelservice.GetModelService(ctx, c, name, namespace, apiType) +} + +// ListModelServices is the resolver for the listModelServices field. +func (r *modelServiceQueryResolver) ListModelServices(ctx context.Context, obj *generated.ModelServiceQuery, input *generated.ListModelService) (*generated.PaginatedResult, error) { + c, err := getClientFromCtx(ctx) + if err != nil { + return nil, err + } + return modelservice.ListModelServices(ctx, c, input) +} + // ModelService is the resolver for the ModelService field. func (r *mutationResolver) ModelService(ctx context.Context) (*generated.ModelServiceMutation, error) { return &generated.ModelServiceMutation{}, nil @@ -46,7 +63,7 @@ func (r *mutationResolver) ModelService(ctx context.Context) (*generated.ModelSe // ModelService is the resolver for the ModelService field. func (r *queryResolver) ModelService(ctx context.Context) (*generated.ModelServiceQuery, error) { - panic(fmt.Errorf("not implemented: ModelService - ModelService")) + return &generated.ModelServiceQuery{}, nil } // ModelServiceMutation returns generated.ModelServiceMutationResolver implementation. @@ -54,4 +71,10 @@ func (r *Resolver) ModelServiceMutation() generated.ModelServiceMutationResolver return &modelServiceMutationResolver{r} } +// ModelServiceQuery returns generated.ModelServiceQueryResolver implementation. +func (r *Resolver) ModelServiceQuery() generated.ModelServiceQueryResolver { + return &modelServiceQueryResolver{r} +} + type modelServiceMutationResolver struct{ *Resolver } +type modelServiceQueryResolver struct{ *Resolver } diff --git a/apiserver/graph/schema/modelservice.gql b/apiserver/graph/schema/modelservice.gql index dc5eb9972..4ae9eb8f3 100644 --- a/apiserver/graph/schema/modelservice.gql +++ b/apiserver/graph/schema/modelservice.gql @@ -40,4 +40,119 @@ mutation deleteModelServices($input: DeleteCommonInput) { ModelService { deleteModelService(input: $input) } +} + +query getModelService($name: String!, $namespace: String!, $apiType: String!) { + ModelService { + getModelService(name: $name, namespace: $namespace, apiType: $apiType) { + id + name + namespace + labels + annotations + creator + displayName + description + types + apiType + creationTimestamp + updateTimestamp + llmResource { + name + namespace + labels + annotations + displayName + description + baseUrl + models + provider + type + updateTimestamp + status + message + } + embedderResource{ + name + namespace + labels + annotations + displayName + description + baseUrl + models + provider + type + updateTimestamp + status + message + } + resource { + cpu + memory + nvidiaGPU + } + } + } +} + +query listModelServices($input: ListModelService) { + ModelService { + listModelServices(input: $input) { + totalCount + hasNextPage + nodes { + __typename + ... on ModelService { + id + name + namespace + labels + annotations + creator + displayName + description + types + apiType + creationTimestamp + updateTimestamp + llmResource { + name + namespace + labels + annotations + displayName + description + baseUrl + models + provider + type + updateTimestamp + status + message + } + embedderResource{ + name + namespace + labels + annotations + displayName + description + baseUrl + models + provider + type + updateTimestamp + status + message + } + resource { + cpu + memory + nvidiaGPU + } + } + } + } + } } \ No newline at end of file diff --git a/apiserver/graph/schema/modelservice.graphqls b/apiserver/graph/schema/modelservice.graphqls index 03832e248..f73f37fdd 100644 --- a/apiserver/graph/schema/modelservice.graphqls +++ b/apiserver/graph/schema/modelservice.graphqls @@ -27,6 +27,11 @@ type ModelService { """ llmResource: LLM embedderResource: Embedder + + """ + 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 + """ + resource: Resources } input CreateModelServiceInput { """模型服务资源名称(不可同名)""" @@ -101,6 +106,7 @@ input ListModelService { keyword: String namespace: String! page: Int + pageSize: Int """ all, llm, embedding @@ -110,12 +116,12 @@ input ListModelService { """ worker, 3rd """ - providerType: String! + providerType: String """ openai, zhipuai """ - apiType: String! + apiType: String } type ModelServiceMutation { @@ -129,10 +135,10 @@ extend type Mutation { } type ModelServiceQuery { - getModelService(name: String!, apiType: String!): ModelService - listModelServices(input: ListModelService): [ModelService] + getModelService(name: String!, namespace: String!, apiType: String!): ModelService + listModelServices(input: ListModelService): PaginatedResult! } extend type Query { ModelService: ModelServiceQuery -} \ No newline at end of file +} diff --git a/apiserver/pkg/embedder/embedder.go b/apiserver/pkg/embedder/embedder.go index d610e5d3e..e7e4d37f4 100644 --- a/apiserver/pkg/embedder/embedder.go +++ b/apiserver/pkg/embedder/embedder.go @@ -211,6 +211,7 @@ func DeleteEmbedders(ctx context.Context, c dynamic.Interface, input *generated. } return nil, nil } + func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput) (*generated.PaginatedResult, error) { keyword, labelSelector, fieldSelector := "", "", "" page, pageSize := 1, 10 diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index 877a74742..550e4dd0f 100644 --- a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -18,14 +18,17 @@ package modelservice import ( "context" + "fmt" "strings" "time" "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/embedder" "github.com/kubeagi/arcadia/apiserver/pkg/llm" + "github.com/kubeagi/arcadia/apiserver/pkg/worker" ) func CreateModelService(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (*generated.ModelService, error) { @@ -169,3 +172,371 @@ func DeleteModelService(ctx context.Context, c dynamic.Interface, input *generat } return nil, nil } + +var ( + fixedPage = 1 + + // because the data is to be paged, no parameters are provided, + // and the default return is 10. Getting modelserve needs to get all the llms and embedding, + // so a larger pageSize is provided here. + fixedPageSize = 100 +) + +const ( + modelTypeAll = "all" + modelTypeLLM = "llm" + modelTypeEmbedding = "embedding" +) + +func debugModelService(m *generated.ModelService) string { + id := "" + if m.ID != nil { + id = *m.ID + } + + creator := "" + if m.Creator != nil { + creator = *m.Creator + } + types := "" + if m.Types != nil { + types = *m.Types + } + return fmt.Sprintf("{id: %s, creator: %s, types: %s, apiType: %s, creationTimestamp: %s, updateTimestamp: %s}", + id, creator, types, *m.APIType, m.CreationTimestamp, m.UpdateTimestamp) +} + +func ListModelServices(ctx context.Context, c dynamic.Interface, input *generated.ListModelService) (*generated.PaginatedResult, error) { + var ( + llmList, embedderList, workerList *generated.PaginatedResult + err error + ) + query := generated.ListCommonInput{Page: &fixedPage, PageSize: &fixedPageSize, Namespace: input.Namespace, Keyword: input.Keyword} + + // list all llms + llmList, err = llm.ListLLMs(ctx, c, query) + if err != nil { + klog.Errorf("failed to list llm %s", err) + return nil, err + } + + // list all embedders + embedderList, err = embedder.ListEmbedders(ctx, c, query) + if err != nil { + klog.Errorf("failed to list embedder %s", err) + + return nil, err + } + + // list all workers + workerList, err = worker.ListWorkers(ctx, c, generated.ListWorkerInput{Namespace: input.Namespace, Page: &fixedPage, PageSize: &fixedPageSize}) + if err != nil { + klog.Errorf("failed to list worker %s", err) + return nil, err + } + + workerResource := make(map[string]*generated.Worker) + for _, item := range workerList.Nodes { + v := item.(*generated.Worker) + workerResource[v.Name] = v + } + + modelServiceList := make([]*generated.ModelService, 0) + intersec := make(map[string]*generated.ModelService) + klog.V(5).Infof("namespace: %s modetype: %s, providerType: %s, apiType: %s", + input.Namespace, input.ModelType, input.ProviderType, input.APIType) + + // The overall idea is to get the llm-filtered list first when not filtering on modelType or when you want to filter on llm. + // Or to filter embedder, first get the list of embedder. + if input.ModelType == "" || input.ModelType == modelTypeAll || input.ModelType == modelTypeLLM { + for idx := range llmList.Nodes { + v := llmList.Nodes[idx].(*generated.Llm) + klog.V(5).Infof("add llm modelservice: llm %s, type: %s, provider: %s. filter apiType: %s, filter providers: %s", + v.Name, *v.Type, *v.Provider, input.APIType, input.ProviderType) + if input.APIType != nil && *v.Type != *input.APIType { + continue + } + if input.ProviderType != nil && *v.Provider != *input.ProviderType { + continue + } + + ms := &generated.ModelService{ + Name: v.Name, + Namespace: input.Namespace, + Creator: v.Creator, + Description: v.Description, + Types: new(string), + CreationTimestamp: v.CreationTimestamp, + UpdateTimestamp: v.UpdateTimestamp, + APIType: v.Type, + LlmResource: v, + EmbedderResource: new(generated.Embedder), + Resource: new(generated.Resources), + } + *ms.Types = modelTypeLLM + intersec[v.Name] = ms + + modelServiceList = append(modelServiceList, ms) + klog.V(5).Infof("add llm modelservice: append only llm modelService to list: %s", debugModelService(ms)) + if r, ok := workerResource[v.Name]; ok { + klog.V(5).Infof(" set llm modelservice: %s resource: set modelservice resource: %+v", v.Name, r) + *ms.Resource = r.Resources + ms.ID = r.ID + } + } + } else { + for idx := range embedderList.Nodes { + v := embedderList.Nodes[idx].(*generated.Embedder) + klog.V(5).Infof("add embedder modelservice: embedder %s, type: %s, provider: %s. filter apiType: %s, filter providers: %s", + v.Name, *v.Type, *v.Provider, input.APIType, input.ProviderType) + + if input.APIType != nil && *v.Type != *input.APIType { + continue + } + if input.ProviderType != nil && *v.Provider != *input.ProviderType { + continue + } + + ms := &generated.ModelService{ + Name: v.Name, + Namespace: input.Namespace, + Creator: v.Creator, + Description: v.Description, + Types: new(string), + CreationTimestamp: v.CreationTimestamp, + UpdateTimestamp: v.UpdateTimestamp, + APIType: v.Type, + EmbedderResource: v, + LlmResource: new(generated.Llm), + Resource: new(generated.Resources), + } + *ms.Types = modelTypeEmbedding + + intersec[v.Name] = ms + modelServiceList = append(modelServiceList, ms) + klog.V(5).Infof("add embedder modelservice: append only embedder modelService to list: %s", debugModelService(ms)) + + if r, ok := workerResource[v.Name]; ok { + klog.V(5).Infof("set embedder modelservice: %s resource: %+v", v.Name, r) + *ms.Resource = r.Resources + ms.ID = r.ID + } + } + } + + page, pageSize := 1, 10 + if input.Page != nil && *input.Page > 0 { + page = *input.Page + } + if input.PageSize != nil && *input.PageSize > 0 { + pageSize = *input.PageSize + } + + // If we are filtering llm or embedding, + // if the list obtained earlier is empty (e.g. llmList), then we don't need to judge the embeddings anymore. + if len(modelServiceList) == 0 && input.ModelType != "" && input.ModelType != modelTypeAll { + return &generated.PaginatedResult{ + HasNextPage: false, + Nodes: []generated.PageNode{}, + TotalCount: 0, + }, nil + } + + // After getting the list of llm's, + // we need to determine if any embedder meets the filter condition and if any embedder can be merged into the modeservice of the llm. + // If an embedder meets the filter condition, we need to determine whether to create a new modelservice or merge it into the modelservice of llm. + // The following logic is the same. + if input.ModelType == "" || input.ModelType == modelTypeAll || input.ModelType == modelTypeLLM { + for idx := range embedderList.Nodes { + v := embedderList.Nodes[idx].(*generated.Embedder) + if input.APIType != nil && *v.Type != *input.APIType { + continue + } + if input.ProviderType != nil && *v.Provider != *input.ProviderType { + continue + } + llm, ok := intersec[v.Name] + if !ok && input.ModelType != "" && input.ModelType != modelTypeAll { + continue + } + if !ok || *llm.APIType != *v.Type || *llm.LlmResource.Provider != *v.Provider { + if !ok { + klog.V(5).Infof("match check llm: embedder %s has no matching llm's, add modelservice", v.Name) + } + if ok && (*llm.APIType != *v.Type || *llm.LlmResource.Provider != *v.Provider) { + klog.V(5).Infof("match check llm: embedder %s type: %s, llm apiType: %s, llm provider: %s, embedder proviers: %s. add modelservice", + v.Name, *v.Type, *llm.APIType, *llm.LlmResource.Provider, *v.Provider) + } + + ms := &generated.ModelService{ + Name: v.Name, + Namespace: input.Namespace, + Creator: v.Creator, + Description: v.Description, + Types: new(string), + CreationTimestamp: v.CreationTimestamp, + UpdateTimestamp: v.UpdateTimestamp, + APIType: v.Type, + LlmResource: new(generated.Llm), + EmbedderResource: v, + Resource: new(generated.Resources), + } + *ms.Types = modelTypeEmbedding + if r, ok := workerResource[v.Name]; ok { + klog.V(5).Infof("match check llm: set modelservice %s resource: %+v", v.Name, r) + *ms.Resource = r.Resources + ms.ID = r.ID + } + + klog.V(5).Infof("match check llm: append only embedder modelService to list: %s", debugModelService(ms)) + modelServiceList = append(modelServiceList, ms) + continue + } + + *llm.Types = modelTypeLLM + "," + modelTypeEmbedding + llm.EmbedderResource = v + if llm.CreationTimestamp.After(*v.CreationTimestamp) { + llm.CreationTimestamp = v.CreationTimestamp + } + if llm.UpdateTimestamp.Before(*v.UpdateTimestamp) { + llm.UpdateTimestamp = v.UpdateTimestamp + } + klog.V(5).Infof("match check llm: embedder match llm %s", debugModelService(llm)) + } + } else { + for idx := range llmList.Nodes { + v := llmList.Nodes[idx].(*generated.Llm) + if input.APIType != nil && *v.Type != *input.APIType { + continue + } + if input.ProviderType != nil && *v.Provider != *input.ProviderType { + continue + } + + embedder, ok := intersec[v.Name] + if !ok && input.ModelType != "" && input.ModelType != modelTypeAll { + continue + } + + if !ok || embedder.Name != v.Name || *embedder.APIType != *v.Type || *embedder.EmbedderResource.Provider != *v.Provider { + if !ok { + klog.V(5).Infof("match check embedder: llm %s has no matching embedder's, add modelservice", v.Name) + } + if ok && (*embedder.APIType != *v.Type || *embedder.EmbedderResource.Provider != *v.Provider) { + klog.V(5).Infof("match check embedder: llm %s type: %s, embedder apiType: %s, embedder provider: %s, llm proviers: %s. add modelservice", + v.Name, *v.Type, *embedder.APIType, *embedder.EmbedderResource.Provider, *v.Provider) + } + + ms := &generated.ModelService{ + Name: v.Name, + Namespace: input.Namespace, + Creator: v.Creator, + Description: v.Description, + Types: new(string), + CreationTimestamp: v.CreationTimestamp, + UpdateTimestamp: v.UpdateTimestamp, + APIType: v.Type, + LlmResource: v, + EmbedderResource: new(generated.Embedder), + Resource: new(generated.Resources), + } + *ms.Types = modelTypeLLM + if r, ok := workerResource[v.Name]; ok { + klog.V(5).Infof("match check embedder: set modelservice %s resource: %+v", v.Name, r) + *ms.Resource = r.Resources + ms.ID = v.ID + } + klog.V(5).Infof("match check embedder: append only llm modelService to list: %s", debugModelService(ms)) + modelServiceList = append(modelServiceList, ms) + continue + } + + *embedder.Types = modelTypeLLM + "," + modelTypeEmbedding + embedder.LlmResource = v + if embedder.CreationTimestamp.After(*v.CreationTimestamp) { + embedder.CreationTimestamp = v.CreationTimestamp + } + if embedder.UpdateTimestamp.Before(*v.UpdateTimestamp) { + embedder.UpdateTimestamp = v.UpdateTimestamp + } + klog.V(5).Infof("match check llm: embedder match llm %s", debugModelService(embedder)) + } + } + + start := (page - 1) * pageSize + end := start + pageSize + total := len(modelServiceList) + + result := pageModelService(start, end, &modelServiceList) + nodes := make([]generated.PageNode, len(result)) + for idx := range result { + nodes[idx] = result[idx] + } + return &generated.PaginatedResult{ + HasNextPage: end < total, + Nodes: nodes, + Page: &page, + PageSize: &pageSize, + TotalCount: total, + }, nil +} + +func GetModelService(ctx context.Context, c dynamic.Interface, name, namespace, apiType string) (*generated.ModelService, error) { + ms := &generated.ModelService{ + ID: new(string), + Name: name, + Namespace: namespace, + Creator: new(string), + Description: new(string), + Types: new(string), + CreationTimestamp: new(time.Time), + UpdateTimestamp: new(time.Time), + APIType: &apiType, + LlmResource: new(generated.Llm), + EmbedderResource: new(generated.Embedder), + Resource: new(generated.Resources), + } + exist := false + if r1, err := llm.ReadLLM(ctx, c, name, namespace); err == nil && *r1.Type == apiType { + exist = true + if r1.CreationTimestamp != nil { + *ms.CreationTimestamp = *r1.CreationTimestamp + } + if r1.UpdateTimestamp != nil { + *ms.UpdateTimestamp = *r1.UpdateTimestamp + } + *ms.Types = "llm" + if r1.Description != nil { + *ms.Description = *r1.Description + } + *ms.LlmResource = *r1 + } + + if r2, err := embedder.ReadEmbedder(ctx, c, name, namespace); err == nil && *r2.Type == apiType { + exist = true + if r2.CreationTimestamp != nil && (ms.CreationTimestamp == nil || r2.CreationTimestamp.Before(*ms.CreationTimestamp)) { + *ms.CreationTimestamp = *r2.CreationTimestamp + } + if r2.UpdateTimestamp != nil && (ms.UpdateTimestamp == nil || r2.UpdateTimestamp.After(*ms.UpdateTimestamp)) { + *ms.UpdateTimestamp = *r2.UpdateTimestamp + } + if *ms.Description == "" && r2.Description != nil { + *ms.Description = *r2.Description + } + if *ms.Types == modelTypeLLM { + *ms.Types += "," + modelTypeEmbedding + } else { + *ms.Types = modelTypeEmbedding + } + *ms.EmbedderResource = *r2 + } + + if r3, err := worker.ReadWorker(ctx, c, name, namespace); err == nil { + *ms.Resource = r3.Resources + *ms.ID = *r3.ID + } + if !exist { + return nil, fmt.Errorf("not found modelService %s", name) + } + return ms, nil +} diff --git a/apiserver/pkg/modelservice/sort.go b/apiserver/pkg/modelservice/sort.go new file mode 100644 index 000000000..8f08d5340 --- /dev/null +++ b/apiserver/pkg/modelservice/sort.go @@ -0,0 +1,72 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package modelservice + +import ( + "container/heap" + + "github.com/kubeagi/arcadia/apiserver/graph/generated" +) + +// Because here the merger of llm and embedders two kinds of data, +// so the amount of data than one kind of quite a bit more, +// so the direct sort may not be an optimal solution, +// we can use the heap, and start, end the two parameters do not arrange all the data, and then find the data to be paged. +type ModelServiceList []*generated.ModelService + +func (m *ModelServiceList) Len() int { + return len(*m) +} + +func (m *ModelServiceList) Less(i, j int) bool { + a := (*m)[i] + b := (*m)[j] + if a.CreationTimestamp.Equal(*b.CreationTimestamp) { + return a.Name < b.Name + } + return a.CreationTimestamp.After(*b.CreationTimestamp) +} + +func (m *ModelServiceList) Swap(i, j int) { + (*m)[i], (*m)[j] = (*m)[j], (*m)[i] +} + +func (m *ModelServiceList) Push(x any) { + *m = append(*m, x.(*generated.ModelService)) +} + +func (m *ModelServiceList) Pop() any { + old := *m + l := len(old) + x := old[l-1] + *m = old[:l-1] + return x +} + +func pageModelService(start, end int, list *[]*generated.ModelService) []*generated.ModelService { + l := ModelServiceList(*list) + heap.Init(&l) + + r := make([]*generated.ModelService, 0) + for cur := 0; cur < end && l.Len() > 0; cur++ { + top := heap.Pop(&l) + if cur < start { + continue + } + r = append(r, top.(*generated.ModelService)) + } + return r +} diff --git a/apiserver/pkg/modelservice/sort_test.go b/apiserver/pkg/modelservice/sort_test.go new file mode 100644 index 000000000..de9657d3e --- /dev/null +++ b/apiserver/pkg/modelservice/sort_test.go @@ -0,0 +1,76 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package modelservice + +import ( + "testing" + "time" + + "github.com/kubeagi/arcadia/apiserver/graph/generated" +) + +type input struct { + start, end int + exp []string +} + +func initData() []*generated.ModelService { + seconds := []int{ + 0, -2, 4, -6, 8, -10, 12, -14, 16, -18, 20, -22, 24, 26, 28, 30, + } + // sorted order: 30,28,26,24,33, 66, 5, 4, 7, 3, 1, 2, 6, 10, 55, 88 + source := []*generated.ModelService{ + {Name: "3"}, {Name: "1"}, {Name: "7"}, + {Name: "2"}, {Name: "4"}, {Name: "6"}, + {Name: "5"}, {Name: "10"}, {Name: "66"}, + {Name: "55"}, {Name: "33"}, {Name: "88"}, + {Name: "24"}, {Name: "26"}, {Name: "28"}, + {Name: "30"}, + } + now := time.Now() + for idx := range source { + tmp := now.Add(time.Second * time.Duration(seconds[idx])) + source[idx].CreationTimestamp = &tmp + } + return source +} + +func TestPagedModelService(t *testing.T) { + for _, tc := range []input{ + {0, 3, []string{"30", "28", "26"}}, + {3, 6, []string{"24", "33", "66"}}, + {6, 9, []string{"5", "4", "7"}}, + {9, 12, []string{"3", "1", "2"}}, + {12, 15, []string{"6", "10", "55"}}, + {15, 18, []string{"88"}}, + {0, 7, []string{"30", "28", "26", "24", "33", "66", "5"}}, + {7, 14, []string{"4", "7", "3", "1", "2", "6", "10"}}, + {14, 18, []string{"55", "88"}}, + {0, 10, []string{"30", "28", "26", "24", "33", "66", "5", "4", "7", "3"}}, + {10, 20, []string{"1", "2", "6", "10", "55", "88"}}, + } { + source := initData() + r := pageModelService(tc.start, tc.end, &source) + if len(tc.exp) != len(r) { + t.Fatalf("expect %d items, got %d", len(tc.exp), len(r)) + } + for idx := range r { + if r[idx].Name != tc.exp[idx] { + t.Fatalf("[%d - %d] expects the i-th to be %s but actually is %s", tc.start, tc.end, tc.exp[idx], r[idx].Name) + } + } + } +} diff --git a/gqlgen.yaml b/gqlgen.yaml index a1fad40e9..56e534540 100644 --- a/gqlgen.yaml +++ b/gqlgen.yaml @@ -246,4 +246,9 @@ models: resolver: true deleteModelService: resolver: true - + ModelServiceQuery: + fields: + getModelService: + resolver: true + listModelServices: + resolver: true