diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index cd254e099..bb5f30aa5 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -423,17 +423,18 @@ type ComplexityRoot struct { ModelService struct { APIType func(childComplexity int) int Annotations func(childComplexity int) int + BaseURL func(childComplexity int) int CreationTimestamp func(childComplexity int) int Creator func(childComplexity int) int Description func(childComplexity int) int DisplayName func(childComplexity int) int - EmbedderResource func(childComplexity int) int ID func(childComplexity int) int Labels func(childComplexity int) int - LlmResource func(childComplexity int) int + Message func(childComplexity int) int Name func(childComplexity int) int Namespace func(childComplexity int) int - Resource func(childComplexity int) int + ProviderType func(childComplexity int) int + Status func(childComplexity int) int Types func(childComplexity int) int UpdateTimestamp func(childComplexity int) int } @@ -446,8 +447,8 @@ type ComplexityRoot struct { ModelServiceQuery struct { CheckModelService func(childComplexity int, input CreateModelServiceInput) int - GetModelService func(childComplexity int, name string, namespace string, apiType string) int - ListModelServices func(childComplexity int, input *ListModelService) int + GetModelService func(childComplexity int, name string, namespace string) int + ListModelServices func(childComplexity int, input *ListModelServiceInput) int } Mutation struct { @@ -675,8 +676,8 @@ type ModelServiceMutationResolver interface { DeleteModelService(ctx context.Context, obj *ModelServiceMutation, input *DeleteCommonInput) (*string, error) } type ModelServiceQueryResolver interface { - GetModelService(ctx context.Context, obj *ModelServiceQuery, name string, namespace string, apiType string) (*ModelService, error) - ListModelServices(ctx context.Context, obj *ModelServiceQuery, input 
*ListModelService) (*PaginatedResult, error) + GetModelService(ctx context.Context, obj *ModelServiceQuery, name string, namespace string) (*ModelService, error) + ListModelServices(ctx context.Context, obj *ModelServiceQuery, input *ListModelServiceInput) (*PaginatedResult, error) CheckModelService(ctx context.Context, obj *ModelServiceQuery, input CreateModelServiceInput) (*ModelService, error) } type MutationResolver interface { @@ -2597,6 +2598,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ModelService.Annotations(childComplexity), true + case "ModelService.baseUrl": + if e.complexity.ModelService.BaseURL == nil { + break + } + + return e.complexity.ModelService.BaseURL(childComplexity), true + case "ModelService.creationTimestamp": if e.complexity.ModelService.CreationTimestamp == nil { break @@ -2625,13 +2633,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ModelService.DisplayName(childComplexity), true - case "ModelService.embedderResource": - if e.complexity.ModelService.EmbedderResource == nil { - break - } - - return e.complexity.ModelService.EmbedderResource(childComplexity), true - case "ModelService.id": if e.complexity.ModelService.ID == nil { break @@ -2646,12 +2647,12 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ModelService.Labels(childComplexity), true - case "ModelService.llmResource": - if e.complexity.ModelService.LlmResource == nil { + case "ModelService.message": + if e.complexity.ModelService.Message == nil { break } - return e.complexity.ModelService.LlmResource(childComplexity), true + return e.complexity.ModelService.Message(childComplexity), true case "ModelService.name": if e.complexity.ModelService.Name == nil { @@ -2667,12 +2668,19 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 
e.complexity.ModelService.Namespace(childComplexity), true - case "ModelService.resource": - if e.complexity.ModelService.Resource == nil { + case "ModelService.providerType": + if e.complexity.ModelService.ProviderType == nil { + break + } + + return e.complexity.ModelService.ProviderType(childComplexity), true + + case "ModelService.status": + if e.complexity.ModelService.Status == nil { break } - return e.complexity.ModelService.Resource(childComplexity), true + return e.complexity.ModelService.Status(childComplexity), true case "ModelService.types": if e.complexity.ModelService.Types == nil { @@ -2746,7 +2754,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.ModelServiceQuery.GetModelService(childComplexity, args["name"].(string), args["namespace"].(string), args["apiType"].(string)), true + return e.complexity.ModelServiceQuery.GetModelService(childComplexity, args["name"].(string), args["namespace"].(string)), true case "ModelServiceQuery.listModelServices": if e.complexity.ModelServiceQuery.ListModelServices == nil { @@ -2758,7 +2766,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.ModelServiceQuery.ListModelServices(childComplexity, args["input"].(*ListModelService)), true + return e.complexity.ModelServiceQuery.ListModelServices(childComplexity, args["input"].(*ListModelServiceInput)), true case "Mutation.Application": if e.complexity.Mutation.Application == nil { @@ -3518,7 +3526,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputListDatasetInput, ec.unmarshalInputListKnowledgeBaseInput, ec.unmarshalInputListModelInput, - ec.unmarshalInputListModelService, + ec.unmarshalInputListModelServiceInput, ec.unmarshalInputListVersionedDatasetInput, ec.unmarshalInputListWorkerInput, ec.unmarshalInputOssInput, @@ -5205,37 +5213,63 @@ extend type Query { id: String 
name: String! namespace: String! + labels: Map annotations: Map + creator: String displayName: String description: String + + """ + 模型服务的创建和更新时间 + """ + creationTimestamp: Time + updateTimestamp: Time + + + """ + 模型服务供应商的类型 + 规则: 3rd_party 第三方 + 规则: worker 本地 + """ + providerType: String + + """ 模型服务能力类型,支持 llm 和 embedding 两种模型类型 规则: 如果该模型支持多种模型类型,则可多选。多选后组成的字段通过逗号隔开。如 "llm,embedding" """ types: String - creationTimestamp: Time - updateTimestamp: Time - """ 模型服务 API 类型 - 规则:与 pkgs/llms.LLMType 相同,支持 openai, zhipuai 两种类型 + 规则:支持 openai, zhipuai 两种类型 """ apiType: String + """ - 模型对应的 LLM 及 embedder CR 资源 + 服务地址: 仅针对第三方模型服务 """ - llmResource: LLM - embedderResource: Embedder + baseUrl: String! """ - 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 + 状态 + 规则: 目前分为六种状态 + - True: 正常 (第三方模型服务) + - False: 异常 (第三方模型服务) + - Unknown: 未知 (本地模型服务) + - Pending: 发布中 (本地模型服务) + - Running: 已发布 (本地模型服务) + - Error: 异常 (本地模型服务) """ - resource: Resources + status: String + + """详细的状态消息描述""" + message: String } + input CreateModelServiceInput { """模型服务资源名称(不可同名)""" name: String! @@ -5302,7 +5336,7 @@ input UpdateModelServiceInput { endpoint: EndpointInput! } -input ListModelService { +input ListModelServiceInput { """ 关键词搜索 """ @@ -5312,17 +5346,30 @@ input ListModelService { pageSize: Int """ - all, llm, embedding + 模型服务的类型 + 规则: + - 为空默认不过滤 + - llm 则仅返回LLM模型服务 + - embedding 则仅返回Embedding模型服务 + - llm,embedding 则返回同时提供LLM和Embedding能力的模型服务 """ - modelType: String! + types: String """ - worker, 3rd + 模型服务供应商类型 + 规则: + - 为空默认不过滤 + - worker 则仅返回本地模型服务 + - 3rd_party 则仅返回第三方模型服务 """ providerType: String """ - openai, zhipuai + 模型服务供应商类型 + 规则: + - 为空默认不过滤 + - openai 则仅返回接口类型类型为openai的模型服务 + - zhipuai 则仅返回接口类型类型为zhipuai的模型服务 """ apiType: String } @@ -5338,8 +5385,8 @@ extend type Mutation { } type ModelServiceQuery { - getModelService(name: String!, namespace: String!, apiType: String!): ModelService - listModelServices(input: ListModelService): PaginatedResult! 
+ getModelService(name: String!, namespace: String!): ModelService! + listModelServices(input: ListModelServiceInput): PaginatedResult! checkModelService(input: CreateModelServiceInput!): ModelService! } @@ -6549,25 +6596,16 @@ func (ec *executionContext) field_ModelServiceQuery_getModelService_args(ctx con } } args["namespace"] = arg1 - var arg2 string - if tmp, ok := rawArgs["apiType"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("apiType")) - arg2, err = ec.unmarshalNString2string(ctx, tmp) - if err != nil { - return nil, err - } - } - args["apiType"] = arg2 return args, nil } func (ec *executionContext) field_ModelServiceQuery_listModelServices_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 *ListModelService + var arg0 *ListModelServiceInput if tmp, ok := rawArgs["input"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) - arg0, err = ec.unmarshalOListModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐListModelService(ctx, tmp) + arg0, err = ec.unmarshalOListModelServiceInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐListModelServiceInput(ctx, tmp) if err != nil { return nil, err } @@ -18459,8 +18497,8 @@ func (ec *executionContext) fieldContext_ModelService_description(ctx context.Co return fc, nil } -func (ec *executionContext) _ModelService_types(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_types(ctx, field) +func (ec *executionContext) _ModelService_creationTimestamp(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_creationTimestamp(ctx, field) if err != nil { return graphql.Null } @@ -18473,7 +18511,7 @@ func (ec *executionContext) _ModelService_types(ctx context.Context, field graph }() resTmp, err 
:= ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.Types, nil + return obj.CreationTimestamp, nil }) if err != nil { ec.Error(ctx, err) @@ -18482,26 +18520,26 @@ func (ec *executionContext) _ModelService_types(ctx context.Context, field graph if resTmp == nil { return graphql.Null } - res := resTmp.(*string) + res := resTmp.(*time.Time) fc.Result = res - return ec.marshalOString2ᚖstring(ctx, field.Selections, res) + return ec.marshalOTime2ᚖtimeᚐTime(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_types(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_creationTimestamp(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - return nil, errors.New("field of type String does not have child fields") + return nil, errors.New("field of type Time does not have child fields") }, } return fc, nil } -func (ec *executionContext) _ModelService_creationTimestamp(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_creationTimestamp(ctx, field) +func (ec *executionContext) _ModelService_updateTimestamp(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_updateTimestamp(ctx, field) if err != nil { return graphql.Null } @@ -18514,7 +18552,7 @@ func (ec *executionContext) _ModelService_creationTimestamp(ctx context.Context, }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from 
middleware stack in children - return obj.CreationTimestamp, nil + return obj.UpdateTimestamp, nil }) if err != nil { ec.Error(ctx, err) @@ -18528,7 +18566,7 @@ func (ec *executionContext) _ModelService_creationTimestamp(ctx context.Context, return ec.marshalOTime2ᚖtimeᚐTime(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_creationTimestamp(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_updateTimestamp(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, @@ -18541,8 +18579,8 @@ func (ec *executionContext) fieldContext_ModelService_creationTimestamp(ctx cont return fc, nil } -func (ec *executionContext) _ModelService_updateTimestamp(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_updateTimestamp(ctx, field) +func (ec *executionContext) _ModelService_providerType(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_providerType(ctx, field) if err != nil { return graphql.Null } @@ -18555,7 +18593,7 @@ func (ec *executionContext) _ModelService_updateTimestamp(ctx context.Context, f }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.UpdateTimestamp, nil + return obj.ProviderType, nil }) if err != nil { ec.Error(ctx, err) @@ -18564,19 +18602,60 @@ func (ec *executionContext) _ModelService_updateTimestamp(ctx context.Context, f if resTmp == nil { return graphql.Null } - res := resTmp.(*time.Time) + res := resTmp.(*string) fc.Result = res - return ec.marshalOTime2ᚖtimeᚐTime(ctx, field.Selections, res) + return ec.marshalOString2ᚖstring(ctx, 
field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_updateTimestamp(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_providerType(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - return nil, errors.New("field of type Time does not have child fields") + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _ModelService_types(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_types(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Types, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ModelService_types(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ModelService", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") }, } return fc, nil 
@@ -18623,8 +18702,8 @@ func (ec *executionContext) fieldContext_ModelService_apiType(ctx context.Contex return fc, nil } -func (ec *executionContext) _ModelService_llmResource(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_llmResource(ctx, field) +func (ec *executionContext) _ModelService_baseUrl(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_baseUrl(ctx, field) if err != nil { return graphql.Null } @@ -18637,69 +18716,38 @@ func (ec *executionContext) _ModelService_llmResource(ctx context.Context, field }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.LlmResource, nil + return obj.BaseURL, nil }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*Llm) + res := resTmp.(string) fc.Result = res - return ec.marshalOLLM2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLlm(ctx, field.Selections, res) + return ec.marshalNString2string(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_llmResource(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_baseUrl(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - switch field.Name { - case "id": - return ec.fieldContext_LLM_id(ctx, field) - case "name": - return ec.fieldContext_LLM_name(ctx, field) - case "namespace": - 
return ec.fieldContext_LLM_namespace(ctx, field) - case "labels": - return ec.fieldContext_LLM_labels(ctx, field) - case "annotations": - return ec.fieldContext_LLM_annotations(ctx, field) - case "creator": - return ec.fieldContext_LLM_creator(ctx, field) - case "displayName": - return ec.fieldContext_LLM_displayName(ctx, field) - case "description": - return ec.fieldContext_LLM_description(ctx, field) - case "baseUrl": - return ec.fieldContext_LLM_baseUrl(ctx, field) - case "models": - return ec.fieldContext_LLM_models(ctx, field) - case "provider": - return ec.fieldContext_LLM_provider(ctx, field) - case "type": - return ec.fieldContext_LLM_type(ctx, field) - case "creationTimestamp": - return ec.fieldContext_LLM_creationTimestamp(ctx, field) - case "updateTimestamp": - return ec.fieldContext_LLM_updateTimestamp(ctx, field) - case "status": - return ec.fieldContext_LLM_status(ctx, field) - case "message": - return ec.fieldContext_LLM_message(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type LLM", field.Name) + return nil, errors.New("field of type String does not have child fields") }, } return fc, nil } -func (ec *executionContext) _ModelService_embedderResource(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_embedderResource(ctx, field) +func (ec *executionContext) _ModelService_status(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_status(ctx, field) if err != nil { return graphql.Null } @@ -18712,7 +18760,7 @@ func (ec *executionContext) _ModelService_embedderResource(ctx context.Context, }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.EmbedderResource, nil + return obj.Status, nil }) if err != nil { ec.Error(ctx, err) @@ -18721,60 
+18769,26 @@ func (ec *executionContext) _ModelService_embedderResource(ctx context.Context, if resTmp == nil { return graphql.Null } - res := resTmp.(*Embedder) + res := resTmp.(*string) fc.Result = res - return ec.marshalOEmbedder2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐEmbedder(ctx, field.Selections, res) + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_embedderResource(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_status(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - switch field.Name { - case "id": - return ec.fieldContext_Embedder_id(ctx, field) - case "name": - return ec.fieldContext_Embedder_name(ctx, field) - case "namespace": - return ec.fieldContext_Embedder_namespace(ctx, field) - case "labels": - return ec.fieldContext_Embedder_labels(ctx, field) - case "annotations": - return ec.fieldContext_Embedder_annotations(ctx, field) - case "creator": - return ec.fieldContext_Embedder_creator(ctx, field) - case "displayName": - return ec.fieldContext_Embedder_displayName(ctx, field) - case "description": - return ec.fieldContext_Embedder_description(ctx, field) - case "baseUrl": - return ec.fieldContext_Embedder_baseUrl(ctx, field) - case "models": - return ec.fieldContext_Embedder_models(ctx, field) - case "provider": - return ec.fieldContext_Embedder_provider(ctx, field) - case "type": - return ec.fieldContext_Embedder_type(ctx, field) - case "creationTimestamp": - return ec.fieldContext_Embedder_creationTimestamp(ctx, field) - case "updateTimestamp": - return ec.fieldContext_Embedder_updateTimestamp(ctx, field) - case "status": - return 
ec.fieldContext_Embedder_status(ctx, field) - case "message": - return ec.fieldContext_Embedder_message(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type Embedder", field.Name) + return nil, errors.New("field of type String does not have child fields") }, } return fc, nil } -func (ec *executionContext) _ModelService_resource(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_ModelService_resource(ctx, field) +func (ec *executionContext) _ModelService_message(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelService_message(ctx, field) if err != nil { return graphql.Null } @@ -18787,7 +18801,7 @@ func (ec *executionContext) _ModelService_resource(ctx context.Context, field gr }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.Resource, nil + return obj.Message, nil }) if err != nil { ec.Error(ctx, err) @@ -18796,27 +18810,19 @@ func (ec *executionContext) _ModelService_resource(ctx context.Context, field gr if resTmp == nil { return graphql.Null } - res := resTmp.(*Resources) + res := resTmp.(*string) fc.Result = res - return ec.marshalOResources2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResources(ctx, field.Selections, res) + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_ModelService_resource(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_ModelService_message(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "ModelService", Field: field, IsMethod: false, IsResolver: false, Child: func(ctx context.Context, field graphql.CollectedField) 
(*graphql.FieldContext, error) { - switch field.Name { - case "cpu": - return ec.fieldContext_Resources_cpu(ctx, field) - case "memory": - return ec.fieldContext_Resources_memory(ctx, field) - case "nvidiaGPU": - return ec.fieldContext_Resources_nvidiaGPU(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type Resources", field.Name) + return nil, errors.New("field of type String does not have child fields") }, } return fc, nil @@ -18877,20 +18883,22 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_createModelService return ec.fieldContext_ModelService_displayName(ctx, field) case "description": return ec.fieldContext_ModelService_description(ctx, field) - case "types": - return ec.fieldContext_ModelService_types(ctx, field) case "creationTimestamp": return ec.fieldContext_ModelService_creationTimestamp(ctx, field) case "updateTimestamp": return ec.fieldContext_ModelService_updateTimestamp(ctx, field) + case "providerType": + return ec.fieldContext_ModelService_providerType(ctx, field) + case "types": + return ec.fieldContext_ModelService_types(ctx, field) case "apiType": return ec.fieldContext_ModelService_apiType(ctx, field) - case "llmResource": - return ec.fieldContext_ModelService_llmResource(ctx, field) - case "embedderResource": - return ec.fieldContext_ModelService_embedderResource(ctx, field) - case "resource": - return ec.fieldContext_ModelService_resource(ctx, field) + case "baseUrl": + return ec.fieldContext_ModelService_baseUrl(ctx, field) + case "status": + return ec.fieldContext_ModelService_status(ctx, field) + case "message": + return ec.fieldContext_ModelService_message(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -18964,20 +18972,22 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_updateModelService return ec.fieldContext_ModelService_displayName(ctx, field) case "description": return 
ec.fieldContext_ModelService_description(ctx, field) - case "types": - return ec.fieldContext_ModelService_types(ctx, field) case "creationTimestamp": return ec.fieldContext_ModelService_creationTimestamp(ctx, field) case "updateTimestamp": return ec.fieldContext_ModelService_updateTimestamp(ctx, field) + case "providerType": + return ec.fieldContext_ModelService_providerType(ctx, field) + case "types": + return ec.fieldContext_ModelService_types(ctx, field) case "apiType": return ec.fieldContext_ModelService_apiType(ctx, field) - case "llmResource": - return ec.fieldContext_ModelService_llmResource(ctx, field) - case "embedderResource": - return ec.fieldContext_ModelService_embedderResource(ctx, field) - case "resource": - return ec.fieldContext_ModelService_resource(ctx, field) + case "baseUrl": + return ec.fieldContext_ModelService_baseUrl(ctx, field) + case "status": + return ec.fieldContext_ModelService_status(ctx, field) + case "message": + return ec.fieldContext_ModelService_message(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -19062,18 +19072,21 @@ func (ec *executionContext) _ModelServiceQuery_getModelService(ctx context.Conte }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.ModelServiceQuery().GetModelService(rctx, obj, fc.Args["name"].(string), fc.Args["namespace"].(string), fc.Args["apiType"].(string)) + return ec.resolvers.ModelServiceQuery().GetModelService(rctx, obj, fc.Args["name"].(string), fc.Args["namespace"].(string)) }) if err != nil { ec.Error(ctx, err) return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } res := resTmp.(*ModelService) fc.Result = res - return ec.marshalOModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx, 
field.Selections, res) + return ec.marshalNModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_ModelServiceQuery_getModelService(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -19100,20 +19113,22 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_getModelService(ctx c return ec.fieldContext_ModelService_displayName(ctx, field) case "description": return ec.fieldContext_ModelService_description(ctx, field) - case "types": - return ec.fieldContext_ModelService_types(ctx, field) case "creationTimestamp": return ec.fieldContext_ModelService_creationTimestamp(ctx, field) case "updateTimestamp": return ec.fieldContext_ModelService_updateTimestamp(ctx, field) + case "providerType": + return ec.fieldContext_ModelService_providerType(ctx, field) + case "types": + return ec.fieldContext_ModelService_types(ctx, field) case "apiType": return ec.fieldContext_ModelService_apiType(ctx, field) - case "llmResource": - return ec.fieldContext_ModelService_llmResource(ctx, field) - case "embedderResource": - return ec.fieldContext_ModelService_embedderResource(ctx, field) - case "resource": - return ec.fieldContext_ModelService_resource(ctx, field) + case "baseUrl": + return ec.fieldContext_ModelService_baseUrl(ctx, field) + case "status": + return ec.fieldContext_ModelService_status(ctx, field) + case "message": + return ec.fieldContext_ModelService_message(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -19146,7 +19161,7 @@ func (ec *executionContext) _ModelServiceQuery_listModelServices(ctx context.Con }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.ModelServiceQuery().ListModelServices(rctx, obj, fc.Args["input"].(*ListModelService)) + 
return ec.resolvers.ModelServiceQuery().ListModelServices(rctx, obj, fc.Args["input"].(*ListModelServiceInput)) }) if err != nil { ec.Error(ctx, err) @@ -19254,20 +19269,22 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_checkModelService(ctx return ec.fieldContext_ModelService_displayName(ctx, field) case "description": return ec.fieldContext_ModelService_description(ctx, field) - case "types": - return ec.fieldContext_ModelService_types(ctx, field) case "creationTimestamp": return ec.fieldContext_ModelService_creationTimestamp(ctx, field) case "updateTimestamp": return ec.fieldContext_ModelService_updateTimestamp(ctx, field) + case "providerType": + return ec.fieldContext_ModelService_providerType(ctx, field) + case "types": + return ec.fieldContext_ModelService_types(ctx, field) case "apiType": return ec.fieldContext_ModelService_apiType(ctx, field) - case "llmResource": - return ec.fieldContext_ModelService_llmResource(ctx, field) - case "embedderResource": - return ec.fieldContext_ModelService_embedderResource(ctx, field) - case "resource": - return ec.fieldContext_ModelService_resource(ctx, field) + case "baseUrl": + return ec.fieldContext_ModelService_baseUrl(ctx, field) + case "status": + return ec.fieldContext_ModelService_status(ctx, field) + case "message": + return ec.fieldContext_ModelService_message(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) }, @@ -27723,14 +27740,14 @@ func (ec *executionContext) unmarshalInputListModelInput(ctx context.Context, ob return it, nil } -func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, obj interface{}) (ListModelService, error) { - var it ListModelService +func (ec *executionContext) unmarshalInputListModelServiceInput(ctx context.Context, obj interface{}) (ListModelServiceInput, error) { + var it ListModelServiceInput asMap := map[string]interface{}{} for k, v := range obj.(map[string]interface{}) { asMap[k] = v } - 
fieldsInOrder := [...]string{"keyword", "namespace", "page", "pageSize", "modelType", "providerType", "apiType"} + fieldsInOrder := [...]string{"keyword", "namespace", "page", "pageSize", "types", "providerType", "apiType"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -27773,15 +27790,15 @@ func (ec *executionContext) unmarshalInputListModelService(ctx context.Context, return it, err } it.PageSize = data - case "modelType": + case "types": var err error - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("modelType")) - data, err := ec.unmarshalNString2string(ctx, v) + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("types")) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) if err != nil { return it, err } - it.ModelType = data + it.Types = data case "providerType": var err error @@ -32589,20 +32606,25 @@ func (ec *executionContext) _ModelService(ctx context.Context, sel ast.Selection out.Values[i] = ec._ModelService_displayName(ctx, field, obj) case "description": out.Values[i] = ec._ModelService_description(ctx, field, obj) - case "types": - out.Values[i] = ec._ModelService_types(ctx, field, obj) case "creationTimestamp": out.Values[i] = ec._ModelService_creationTimestamp(ctx, field, obj) case "updateTimestamp": out.Values[i] = ec._ModelService_updateTimestamp(ctx, field, obj) + case "providerType": + out.Values[i] = ec._ModelService_providerType(ctx, field, obj) + case "types": + out.Values[i] = ec._ModelService_types(ctx, field, obj) case "apiType": out.Values[i] = ec._ModelService_apiType(ctx, field, obj) - case "llmResource": - out.Values[i] = ec._ModelService_llmResource(ctx, field, obj) - case "embedderResource": - out.Values[i] = ec._ModelService_embedderResource(ctx, field, obj) - case "resource": - out.Values[i] = ec._ModelService_resource(ctx, field, obj) + case "baseUrl": + out.Values[i] = ec._ModelService_baseUrl(ctx, field, obj) + if out.Values[i] == graphql.Null { + out.Invalids++ + } + case "status": + 
out.Values[i] = ec._ModelService_status(ctx, field, obj) + case "message": + out.Values[i] = ec._ModelService_message(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -32786,6 +32808,9 @@ func (ec *executionContext) _ModelServiceQuery(ctx context.Context, sel ast.Sele } }() res = ec._ModelServiceQuery_getModelService(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } return res } @@ -35893,13 +35918,6 @@ func (ec *executionContext) unmarshalODeleteDataProcessInput2ᚖgithubᚗcomᚋk return &res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) marshalOEmbedder2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐEmbedder(ctx context.Context, sel ast.SelectionSet, v *Embedder) graphql.Marshaler { - if v == nil { - return graphql.Null - } - return ec._Embedder(ctx, sel, v) -} - func (ec *executionContext) marshalOEmbedderMutation2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐEmbedderMutation(ctx context.Context, sel ast.SelectionSet, v *EmbedderMutation) graphql.Marshaler { if v == nil { return graphql.Null @@ -36051,13 +36069,6 @@ func (ec *executionContext) marshalOKnowledgeBaseQuery2ᚖgithubᚗcomᚋkubeagi return ec._KnowledgeBaseQuery(ctx, sel, v) } -func (ec *executionContext) marshalOLLM2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLlm(ctx context.Context, sel ast.SelectionSet, v *Llm) graphql.Marshaler { - if v == nil { - return graphql.Null - } - return ec._LLM(ctx, sel, v) -} - func (ec *executionContext) marshalOLLMConfig2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐLLMConfig(ctx context.Context, sel ast.SelectionSet, v *LLMConfig) graphql.Marshaler { if v == nil { return graphql.Null @@ -36088,11 +36099,11 @@ func (ec *executionContext) unmarshalOListDatasetInput2ᚖgithubᚗcomᚋkubeagi return &res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) unmarshalOListModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐListModelService(ctx 
context.Context, v interface{}) (*ListModelService, error) { +func (ec *executionContext) unmarshalOListModelServiceInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐListModelServiceInput(ctx context.Context, v interface{}) (*ListModelServiceInput, error) { if v == nil { return nil, nil } - res, err := ec.unmarshalInputListModelService(ctx, v) + res, err := ec.unmarshalInputListModelServiceInput(ctx, v) return &res, graphql.ErrorOnPath(ctx, err) } @@ -36126,13 +36137,6 @@ func (ec *executionContext) marshalOModelQuery2ᚖgithubᚗcomᚋkubeagiᚋarcad return ec._ModelQuery(ctx, sel, v) } -func (ec *executionContext) marshalOModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx context.Context, sel ast.SelectionSet, v *ModelService) graphql.Marshaler { - if v == nil { - return graphql.Null - } - return ec._ModelService(ctx, sel, v) -} - func (ec *executionContext) marshalOModelServiceMutation2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelServiceMutation(ctx context.Context, sel ast.SelectionSet, v *ModelServiceMutation) graphql.Marshaler { if v == nil { return graphql.Null @@ -36216,13 +36220,6 @@ func (ec *executionContext) marshalOPaginatedDataProcessItem2ᚖgithubᚗcomᚋk return ec._PaginatedDataProcessItem(ctx, sel, v) } -func (ec *executionContext) marshalOResources2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResources(ctx context.Context, sel ast.SelectionSet, v *Resources) graphql.Marshaler { - if v == nil { - return graphql.Null - } - return ec._Resources(ctx, sel, v) -} - func (ec *executionContext) unmarshalOResourcesInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐResourcesInput(ctx context.Context, v interface{}) (*ResourcesInput, error) { if v == nil { return nil, nil diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index fafe01c0c..90f5619cb 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ 
-882,17 +882,30 @@ type ListModelInput struct { PageSize *int `json:"pageSize,omitempty"` } -type ListModelService struct { +type ListModelServiceInput struct { // 关键词搜索 Keyword *string `json:"keyword,omitempty"` Namespace string `json:"namespace"` Page *int `json:"page,omitempty"` PageSize *int `json:"pageSize,omitempty"` - // all, llm, embedding - ModelType string `json:"modelType"` - // worker, 3rd + // 模型服务的类型 + // 规则: + // - 为空默认不过滤 + // - llm 则仅返回LLM模型服务 + // - embedding 则仅返回Embedding模型服务 + // - llm,embedding 则返回同时提供LLM和Embedding能力的模型服务 + Types *string `json:"types,omitempty"` + // 模型服务供应商类型 + // 规则: + // - 为空默认不过滤 + // - worker 则仅返回本地模型服务 + // - 3rd_party 则仅返回第三方模型服务 ProviderType *string `json:"providerType,omitempty"` - // openai, zhipuai + // 模型服务供应商类型 + // 规则: + // - 为空默认不过滤 + // - openai 则仅返回接口类型类型为openai的模型服务 + // - zhipuai 则仅返回接口类型类型为zhipuai的模型服务 APIType *string `json:"apiType,omitempty"` } @@ -992,19 +1005,32 @@ type ModelService struct { Creator *string `json:"creator,omitempty"` DisplayName *string `json:"displayName,omitempty"` Description *string `json:"description,omitempty"` - // 模型服务能力类型,支持 llm 和 embedding 两种模型类型 - // 规则: 如果该模型支持多种模型类型,则可多选。多选后组成的字段通过逗号隔开。如 "llm,embedding" - Types *string `json:"types,omitempty"` + // 模型服务的创建和更新时间 CreationTimestamp *time.Time `json:"creationTimestamp,omitempty"` UpdateTimestamp *time.Time `json:"updateTimestamp,omitempty"` + // 模型服务供应商的类型 + // 规则: 3rd_party 第三方 + // 规则: worker 本地 + ProviderType *string `json:"providerType,omitempty"` + // 模型服务能力类型,支持 llm 和 embedding 两种模型类型 + // 规则: 如果该模型支持多种模型类型,则可多选。多选后组成的字段通过逗号隔开。如 "llm,embedding" + Types *string `json:"types,omitempty"` // 模型服务 API 类型 - // 规则:与 pkgs/llms.LLMType 相同,支持 openai, zhipuai 两种类型 + // 规则:支持 openai, zhipuai 两种类型 APIType *string `json:"apiType,omitempty"` - // 模型对应的 LLM 及 embedder CR 资源 - LlmResource *Llm `json:"llmResource,omitempty"` - EmbedderResource *Embedder `json:"embedderResource,omitempty"` - // 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 - 
Resource *Resources `json:"resource,omitempty"` + // 服务地址: 仅针对第三方模型服务 + BaseURL string `json:"baseUrl"` + // 状态 + // 规则: 目前分为六种状态 + // - True: 正常 (第三方模型服务) + // - False: 异常 (第三方模型服务) + // - Unknown: 未知 (本地模型服务) + // - Pending: 发布中 (本地模型服务) + // - Running: 已发布 (本地模型服务) + // - Error: 异常 (本地模型服务) + Status *string `json:"status,omitempty"` + // 详细的状态消息描述 + Message *string `json:"message,omitempty"` } func (ModelService) IsPageNode() {} @@ -1016,7 +1042,7 @@ type ModelServiceMutation struct { } type ModelServiceQuery struct { - GetModelService *ModelService `json:"getModelService,omitempty"` + GetModelService ModelService `json:"getModelService"` ListModelServices PaginatedResult `json:"listModelServices"` CheckModelService ModelService `json:"checkModelService"` } diff --git a/apiserver/graph/impl/modelservice.resolvers.go b/apiserver/graph/impl/modelservice.resolvers.go index a61b31d24..27cea7b3e 100644 --- a/apiserver/graph/impl/modelservice.resolvers.go +++ b/apiserver/graph/impl/modelservice.resolvers.go @@ -39,16 +39,16 @@ func (r *modelServiceMutationResolver) DeleteModelService(ctx context.Context, o } // GetModelService is the resolver for the getModelService field. -func (r *modelServiceQueryResolver) GetModelService(ctx context.Context, obj *generated.ModelServiceQuery, name string, namespace string, apiType string) (*generated.ModelService, error) { +func (r *modelServiceQueryResolver) GetModelService(ctx context.Context, obj *generated.ModelServiceQuery, name string, namespace string) (*generated.ModelService, error) { c, err := getClientFromCtx(ctx) if err != nil { return nil, err } - return modelservice.GetModelService(ctx, c, name, namespace, apiType) + return modelservice.ReadModelService(ctx, c, name, namespace) } // ListModelServices is the resolver for the listModelServices field. 
-func (r *modelServiceQueryResolver) ListModelServices(ctx context.Context, obj *generated.ModelServiceQuery, input *generated.ListModelService) (*generated.PaginatedResult, error) { +func (r *modelServiceQueryResolver) ListModelServices(ctx context.Context, obj *generated.ModelServiceQuery, input *generated.ListModelServiceInput) (*generated.PaginatedResult, error) { c, err := getClientFromCtx(ctx) if err != nil { return nil, err diff --git a/apiserver/graph/schema/modelservice.gql b/apiserver/graph/schema/modelservice.gql index c2a57ab22..1dc1c5c3d 100644 --- a/apiserver/graph/schema/modelservice.gql +++ b/apiserver/graph/schema/modelservice.gql @@ -9,45 +9,14 @@ mutation createModelService($input: CreateModelServiceInput!) { creator displayName description + providerType types apiType creationTimestamp updateTimestamp - llmResource { - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - embedderResource { - name - namespace - labels - annotations - displayName - description - type - baseUrl - models - provider - updateTimestamp - status - message - } - resource { - cpu - nvidiaGPU - memory - } + status + message + baseUrl } } } @@ -63,45 +32,14 @@ mutation updateModelService($input: UpdateModelServiceInput) { creator displayName description + providerType types apiType creationTimestamp updateTimestamp - llmResource { - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - embedderResource { - name - namespace - labels - annotations - displayName - description - type - baseUrl - models - provider - updateTimestamp - status - message - } - resource { - cpu - nvidiaGPU - memory - } + status + message + baseUrl } } } @@ -112,9 +50,9 @@ mutation deleteModelServices($input: DeleteCommonInput) { } } -query getModelService($name: String!, $namespace: String!, $apiType: String!) 
{ +query getModelService($name: String!, $namespace: String!) { ModelService { - getModelService(name: $name, namespace: $namespace, apiType: $apiType) { + getModelService(name: $name, namespace: $namespace) { id name namespace @@ -123,50 +61,19 @@ query getModelService($name: String!, $namespace: String!, $apiType: String!) { creator displayName description + providerType types apiType creationTimestamp - updateTimestamp - llmResource { - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - embedderResource{ - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - resource { - cpu - memory - nvidiaGPU - } + updateTimestamp + status + message + baseUrl } } } -query listModelServices($input: ListModelService) { +query listModelServices($input: ListModelServiceInput) { ModelService { listModelServices(input: $input) { totalCount @@ -182,45 +89,14 @@ query listModelServices($input: ListModelService) { creator displayName description + providerType types apiType creationTimestamp - updateTimestamp - llmResource { - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - embedderResource{ - name - namespace - labels - annotations - displayName - description - baseUrl - models - provider - type - updateTimestamp - status - message - } - resource { - cpu - memory - nvidiaGPU - } + updateTimestamp + status + message + baseUrl } } } diff --git a/apiserver/graph/schema/modelservice.graphqls b/apiserver/graph/schema/modelservice.graphqls index e5c689de0..081498e13 100644 --- a/apiserver/graph/schema/modelservice.graphqls +++ b/apiserver/graph/schema/modelservice.graphqls @@ -2,37 +2,63 @@ type ModelService { id: String name: String! namespace: String! 
+ labels: Map annotations: Map + creator: String displayName: String description: String + + """ + 模型服务的创建和更新时间 + """ + creationTimestamp: Time + updateTimestamp: Time + + + """ + 模型服务供应商的类型 + 规则: 3rd_party 第三方 + 规则: worker 本地 + """ + providerType: String + + """ 模型服务能力类型,支持 llm 和 embedding 两种模型类型 规则: 如果该模型支持多种模型类型,则可多选。多选后组成的字段通过逗号隔开。如 "llm,embedding" """ types: String - creationTimestamp: Time - updateTimestamp: Time - """ 模型服务 API 类型 - 规则:与 pkgs/llms.LLMType 相同,支持 openai, zhipuai 两种类型 + 规则:支持 openai, zhipuai 两种类型 """ apiType: String + """ - 模型对应的 LLM 及 embedder CR 资源 + 服务地址: 仅针对第三方模型服务 """ - llmResource: LLM - embedderResource: Embedder + baseUrl: String! """ - 第三方的服务不会有这个字段, 只有内部的Worker创建的才会有这个字段。 + 状态 + 规则: 目前分为六种状态 + - True: 正常 (第三方模型服务) + - False: 异常 (第三方模型服务) + - Unknown: 未知 (本地模型服务) + - Pending: 发布中 (本地模型服务) + - Running: 已发布 (本地模型服务) + - Error: 异常 (本地模型服务) """ - resource: Resources + status: String + + """详细的状态消息描述""" + message: String } + input CreateModelServiceInput { """模型服务资源名称(不可同名)""" name: String! @@ -99,7 +125,7 @@ input UpdateModelServiceInput { endpoint: EndpointInput! } -input ListModelService { +input ListModelServiceInput { """ 关键词搜索 """ @@ -109,17 +135,30 @@ input ListModelService { pageSize: Int """ - all, llm, embedding + 模型服务的类型 + 规则: + - 为空默认不过滤 + - llm 则仅返回LLM模型服务 + - embedding 则仅返回Embedding模型服务 + - llm,embedding 则返回同时提供LLM和Embedding能力的模型服务 """ - modelType: String! + types: String """ - worker, 3rd + 模型服务供应商类型 + 规则: + - 为空默认不过滤 + - worker 则仅返回本地模型服务 + - 3rd_party 则仅返回第三方模型服务 """ providerType: String """ - openai, zhipuai + 模型服务供应商类型 + 规则: + - 为空默认不过滤 + - openai 则仅返回接口类型类型为openai的模型服务 + - zhipuai 则仅返回接口类型类型为zhipuai的模型服务 """ apiType: String } @@ -135,8 +174,8 @@ extend type Mutation { } type ModelServiceQuery { - getModelService(name: String!, namespace: String!, apiType: String!): ModelService - listModelServices(input: ListModelService): PaginatedResult! + getModelService(name: String!, namespace: String!): ModelService! 
+ listModelServices(input: ListModelServiceInput): PaginatedResult! checkModelService(input: CreateModelServiceInput!): ModelService! } diff --git a/apiserver/pkg/common/common.go b/apiserver/pkg/common/common.go index 9dd8ca8b9..5f2764885 100644 --- a/apiserver/pkg/common/common.go +++ b/apiserver/pkg/common/common.go @@ -40,6 +40,13 @@ var ( StatusFalse = "False" ) +// ModelType +var ( + ModelTypeAll = "llm,embedding" + ModelTypeLLM = "llm" + ModelTypeEmbedding = "embedding" +) + // Resource operations // ResourceGet provides a common way to get a resource @@ -146,3 +153,42 @@ func GetObjStatus(obj client.Object) string { return string(condition.Status) } + +// PageNodeConvertFunc convert `any` to a `PageNode` +type PageNodeConvertFunc func(any) generated.PageNode + +var ( + DefaultPageNodeConvertFunc = func(node any) generated.PageNode { + pageNode, ok := node.(generated.PageNode) + if !ok { + return nil + } + return pageNode + } +) + +var ( + // UnlimitedPageSize which means all + UnlimitedPageSize = -1 +) + +// ListOptions for graphql list +type ListOptions struct { + ConvertFunc PageNodeConvertFunc +} + +// DefaultListOptions initialize a ListOptions with default settings +func DefaultListOptions() *ListOptions { + return &ListOptions{ + ConvertFunc: DefaultPageNodeConvertFunc, + } +} + +type ListOptionsFunc func(options *ListOptions) + +// WithPageNodeConvertFunc update the PageNodeConvertFunc +func WithPageNodeConvertFunc(convertFunc PageNodeConvertFunc) ListOptionsFunc { + return func(option *ListOptions) { + option.ConvertFunc = convertFunc + } +} diff --git a/apiserver/pkg/embedder/embedder.go b/apiserver/pkg/embedder/embedder.go index 3ff0118ca..f9b90209a 100644 --- a/apiserver/pkg/embedder/embedder.go +++ b/apiserver/pkg/embedder/embedder.go @@ -31,10 +31,12 @@ import ( "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/common" graphqlutils "github.com/kubeagi/arcadia/apiserver/pkg/utils" + 
"github.com/kubeagi/arcadia/apiserver/pkg/worker" "github.com/kubeagi/arcadia/pkg/embeddings" "github.com/kubeagi/arcadia/pkg/utils" ) +// Embedder2model convert unstructured `CR Embedder` to graphql model `Embedder` func Embedder2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) *generated.Embedder { embedder := &v1alpha1.Embedder{} if err := utils.UnstructuredToStructured(obj, embedder); err != nil { @@ -44,16 +46,22 @@ func Embedder2model(ctx context.Context, c dynamic.Interface, obj *unstructured. id := string(embedder.GetUID()) creationtimestamp := embedder.GetCreationTimestamp().Time - servicetype := string(embedder.Spec.Type) + embedderType := string(embedder.Spec.Type) + provider := string(embedder.Spec.Provider.GetType()) // conditioned status condition := embedder.Status.GetCondition(v1alpha1.TypeReady) updateTime := condition.LastTransitionTime.Time status := common.GetObjStatus(embedder) message := string(condition.Message) - - // provider type - provider := string(embedder.Spec.Provider.GetType()) + // Use worker's status&message if LLM's provider is `Worker` + if embedder.Spec.Provider.GetType() == v1alpha1.ProviderTypeWorker { + w, err := worker.ReadWorker(ctx, c, embedder.Name, embedder.Namespace) + if err == nil { + status = *w.Status + message = *w.Message + } + } // get embedder's api url var baseURL string @@ -73,7 +81,7 @@ func Embedder2model(ctx context.Context, c dynamic.Interface, obj *unstructured. Annotations: graphqlutils.MapStr2Any(obj.GetAnnotations()), DisplayName: &embedder.Spec.DisplayName, Description: &embedder.Spec.Description, - Type: &servicetype, + Type: &embedderType, Provider: &provider, BaseURL: baseURL, Models: embedder.GetModelList(), @@ -214,7 +222,13 @@ func DeleteEmbedders(ctx context.Context, c dynamic.Interface, input *generated. 
return nil, nil } -func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput) (*generated.PaginatedResult, error) { +func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) { + // listOpts in this graphql query + opts := common.DefaultListOptions() + for _, optFunc := range listOpts { + optFunc(opts) + } + keyword, labelSelector, fieldSelector := "", "", "" page, pageSize := 1, 10 if input.Keyword != nil { @@ -248,6 +262,12 @@ func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.Lis totalCount := len(us.Items) + // if pageSize is -1 which means unlimited pagesize,return all + if pageSize == common.UnlimitedPageSize { + page = 1 + pageSize = totalCount + } + result := make([]generated.PageNode, 0, pageSize) pageStart := (page - 1) * pageSize for index, u := range us.Items { @@ -262,7 +282,7 @@ func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.Lis continue } } - result = append(result, m) + result = append(result, opts.ConvertFunc(m)) // break if page size matches if len(result) == pageSize { diff --git a/apiserver/pkg/llm/llm.go b/apiserver/pkg/llm/llm.go index d87139e7f..be3924158 100644 --- a/apiserver/pkg/llm/llm.go +++ b/apiserver/pkg/llm/llm.go @@ -31,10 +31,12 @@ import ( "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/common" graphqlutils "github.com/kubeagi/arcadia/apiserver/pkg/utils" + "github.com/kubeagi/arcadia/apiserver/pkg/worker" "github.com/kubeagi/arcadia/pkg/llms" "github.com/kubeagi/arcadia/pkg/utils" ) +// LLM2model convert unstructured `CR LLM` to graphql model `Llm` func LLM2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) *generated.Llm { llm := &v1alpha1.LLM{} if err := utils.UnstructuredToStructured(obj, llm); err != nil { @@ -44,14 +46,26 @@ func LLM2model(ctx 
context.Context, c dynamic.Interface, obj *unstructured.Unstr id := string(llm.GetUID()) creationtimestamp := llm.GetCreationTimestamp().Time + llmType := string(llm.Spec.Type) + provider := string(llm.Spec.Provider.GetType()) + // conditioned status condition := llm.Status.GetCondition(v1alpha1.TypeReady) updateTime := condition.LastTransitionTime.Time status := common.GetObjStatus(llm) message := string(condition.Message) - - llmType := string(llm.Spec.Type) - provider := string(llm.Spec.Provider.GetType()) + // Use worker's status&message if LLM's provider is `Worker` + if llm.Spec.Provider.GetType() == v1alpha1.ProviderTypeWorker { + w, err := worker.ReadWorker(ctx, c, llm.Name, llm.Namespace) + if err == nil { + if w.Status != nil { + status = *w.Status + } + if w.Message != nil { + message = *w.Message + } + } + } // get llm's api url var baseURL string @@ -82,7 +96,13 @@ func LLM2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstr return &md } -func ListLLMs(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput) (*generated.PaginatedResult, error) { +// ListLLMs return a list of LLMs based on input params +func ListLLMs(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) { + opts := common.DefaultListOptions() + for _, optFunc := range listOpts { + optFunc(opts) + } + keyword, labelSelector, fieldSelector := "", "", "" page, pageSize := 1, 10 if input.Keyword != nil { @@ -101,12 +121,10 @@ func ListLLMs(ctx context.Context, c dynamic.Interface, input generated.ListComm pageSize = *input.PageSize } - listOptions := metav1.ListOptions{ + us, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "LLM")).Namespace(input.Namespace).List(ctx, metav1.ListOptions{ LabelSelector: labelSelector, FieldSelector: fieldSelector, - } - - us, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, 
"LLM")).Namespace(input.Namespace).List(ctx, listOptions) + }) if err != nil { return nil, err } @@ -117,6 +135,12 @@ func ListLLMs(ctx context.Context, c dynamic.Interface, input generated.ListComm totalCount := len(us.Items) + // if pageSize is -1 which means unlimited pagesize,return all + if pageSize == common.UnlimitedPageSize { + page = 1 + pageSize = totalCount + } + result := make([]generated.PageNode, 0, pageSize) pageStart := (page - 1) * pageSize @@ -132,7 +156,9 @@ func ListLLMs(ctx context.Context, c dynamic.Interface, input generated.ListComm continue } } - result = append(result, m) + + // convertFunc + result = append(result, opts.ConvertFunc(m)) // break if page size matches if len(result) == pageSize { diff --git a/apiserver/pkg/modelservice/helper.go b/apiserver/pkg/modelservice/helper.go new file mode 100644 index 000000000..7d6b8b71f --- /dev/null +++ b/apiserver/pkg/modelservice/helper.go @@ -0,0 +1,78 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package modelservice + +import ( + "github.com/kubeagi/arcadia/apiserver/graph/generated" + "github.com/kubeagi/arcadia/apiserver/pkg/common" +) + +// Embedder2ModelService converts graphql model `Embedder` to graphql model `ModelService` +func Embedder2ModelService(embedder *generated.Embedder) *generated.ModelService { + ms := &generated.ModelService{ + // metadata + ID: embedder.ID, + Name: embedder.Name, + Namespace: embedder.Namespace, + CreationTimestamp: embedder.CreationTimestamp, + UpdateTimestamp: embedder.UpdateTimestamp, + // common + Creator: embedder.Creator, + DisplayName: embedder.DisplayName, + Description: embedder.Description, + // ProviderType: worker or 3rd_party + ProviderType: embedder.Provider, + // Model types: llm or embedding + Types: &common.ModelTypeEmbedding, + // APIType of this modelservice + APIType: embedder.Type, + // BaseURL of this Embedder + BaseURL: embedder.BaseURL, + // Status of this model service + Status: embedder.Status, + Message: embedder.Message, + } + return ms +} + +// LLM2ModelService converts graphql model `Llm` to graphql model `ModelService` +func LLM2ModelService(llm *generated.Llm) *generated.ModelService { + ms := &generated.ModelService{ + // metadata + ID: llm.ID, + Name: llm.Name, + Namespace: llm.Namespace, + CreationTimestamp: llm.CreationTimestamp, + UpdateTimestamp: llm.UpdateTimestamp, + // common + Creator: llm.Creator, + DisplayName: llm.DisplayName, + Description: llm.Description, + // ProviderType: worker or 3rd_party + ProviderType: llm.Provider, + // Model types: llm or embedding + Types: &common.ModelTypeLLM, + // APIType of this modelservice + APIType: llm.Type, + // BaseURL of this LLM + BaseURL: llm.BaseURL, + // Status of this model service + Status: llm.Status, + Message: llm.Message, + } + return ms +} diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index c73f3ff12..900b67054 100644 --- 
a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -20,29 +20,26 @@ import ( "context" "errors" "fmt" + "sort" "strings" "time" "k8s.io/client-go/dynamic" - "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/graph/generated" + "github.com/kubeagi/arcadia/apiserver/pkg/common" "github.com/kubeagi/arcadia/apiserver/pkg/embedder" "github.com/kubeagi/arcadia/apiserver/pkg/llm" - "github.com/kubeagi/arcadia/apiserver/pkg/worker" "github.com/kubeagi/arcadia/pkg/llms/openai" "github.com/kubeagi/arcadia/pkg/llms/zhipuai" ) +// CreateModelService creates a 3rd_party model service +// If serviceType is llm,embedding,then a LLM and a Embedder will be created +// - Wrap all elements into *generated.ModelService func CreateModelService(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (*generated.ModelService, error) { // - Get general info from input: displayName, description, types, name & namespace, etc. - // - Create *generated.LLM, *generated.embedder accordingly - // - Wrap all elements into *generated.ModelService displayName, description, serviceType, APIType := "", "", "", "" - var genLLM *generated.Llm - var genEmbed *generated.Embedder - var creationTimestamp, updateTimestamp *time.Time - var err error if input.DisplayName != nil { displayName = *input.DisplayName @@ -57,8 +54,11 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate APIType = *input.APIType } + var modelSerivce = &generated.ModelService{} + + // Create LLM if serviceType contains llm if strings.Contains(serviceType, "llm") { - genLLM, err = llm.CreateLLM(ctx, c, generated.CreateLLMInput{ + llm, err := llm.CreateLLM(ctx, c, generated.CreateLLMInput{ Name: input.Name, Namespace: input.Namespace, DisplayName: &displayName, @@ -71,10 +71,12 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate if err != nil { return nil, err } + modelSerivce = 
LLM2ModelService(llm) } + // Create Embedder if serviceType contains embedding if strings.Contains(serviceType, "embedding") { - genEmbed, err = embedder.CreateEmbedder(ctx, c, generated.CreateEmbedderInput{ + embedder, err := embedder.CreateEmbedder(ctx, c, generated.CreateEmbedderInput{ Name: input.Name, Namespace: input.Namespace, DisplayName: &displayName, @@ -87,35 +89,16 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate if err != nil { return nil, err } - } - if genLLM != nil { - creationTimestamp = genLLM.CreationTimestamp - updateTimestamp = genLLM.UpdateTimestamp - } else if genEmbed != nil { - creationTimestamp = genEmbed.CreationTimestamp - updateTimestamp = genEmbed.UpdateTimestamp + modelSerivce = Embedder2ModelService(embedder) } - ms := generated.ModelService{ - // fulfill all params - // TBD: ID, Creator, Resource - Name: input.Name, - Namespace: input.Namespace, - DisplayName: &displayName, - Description: &description, - Labels: input.Labels, - Annotations: input.Annotations, - Types: &serviceType, - APIType: &APIType, - CreationTimestamp: creationTimestamp, - UpdateTimestamp: updateTimestamp, - LlmResource: genLLM, - EmbedderResource: genEmbed, - } - return &ms, nil + modelSerivce.Types = &serviceType + + return modelSerivce, nil } +// UpdateModelService updates a 3rd_party model service func UpdateModelService(ctx context.Context, c dynamic.Interface, input generated.UpdateModelServiceInput) (*generated.ModelService, error) { name, namespace, displayName := "", "", "" if input.Name != "" { @@ -132,6 +115,7 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input generate if err != nil { return nil, err } + updatedEmbedder, err := embedder.UpdateEmbedder(ctx, c, name, namespace, displayName) if err != nil { return nil, err @@ -156,180 +140,99 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input generate Annotations: input.Annotations, Types: input.Types, APIType: 
input.APIType, - EmbedderResource: updatedEmbedder, - LlmResource: updatedLLM, CreationTimestamp: creationTimestamp, UpdateTimestamp: updateTimestamp, } return ds, nil } +// DeleteModelService deletes a 3rd_party model service func DeleteModelService(ctx context.Context, c dynamic.Interface, input *generated.DeleteCommonInput) (*string, error) { - var errText string - _, err1 := embedder.DeleteEmbedders(ctx, c, input) - if err1 != nil { - errText += "embedder: " + err1.Error() - } - _, err2 := llm.DeleteLLMs(ctx, c, input) - if err2 != nil { - errText += " llm:" + err2.Error() + _, err := embedder.DeleteEmbedders(ctx, c, input) + if err != nil { + return nil, err } - if errText != "" { - return nil, errors.New("error occurred during deleting: " + errText) + _, err = llm.DeleteLLMs(ctx, c, input) + if err != nil { + return nil, err } return nil, nil } -var ( - fixedPage = 1 - // because the data is to be paged, no parameters are provided, - // and the default return is 10. Getting modelserve needs to get all the llms and embedding, - // so a larger pageSize is provided here. 
- fixedPageSize = 100 -) +// ReadModelService reads the LLM and the Embedder with the given name and namespace and converts them to a `ModelService` +func ReadModelService(ctx context.Context, c dynamic.Interface, name string, namespace string) (*generated.ModelService, error) { + var modelService = &generated.ModelService{} -const ( - modelTypeAll = "all" - modelTypeLLM = "llm" - modelTypeEmbedding = "embedding" -) - -func debugModelService(m *generated.ModelService) string { - id := "" - if m.ID != nil { - id = *m.ID + llm, err := llm.ReadLLM(ctx, c, name, namespace) + if err == nil { + modelService = LLM2ModelService(llm) } - - creator := "" - if m.Creator != nil { - creator = *m.Creator + embedder, err := embedder.ReadEmbedder(ctx, c, name, namespace) + if err == nil { + modelService = Embedder2ModelService(embedder) } - types := "" - if m.Types != nil { - types = *m.Types + + if llm != nil && embedder != nil { + modelService.Types = &common.ModelTypeAll } - return fmt.Sprintf("{id: %s, creator: %s, types: %s, apiType: %s, creationTimestamp: %s, updateTimestamp: %s}", - id, creator, types, *m.APIType, m.CreationTimestamp, m.UpdateTimestamp) + + return modelService, nil } -func ListModelServices(ctx context.Context, c dynamic.Interface, input *generated.ListModelService) (*generated.PaginatedResult, error) { - var ( - llmList, embedderList, workerList *generated.PaginatedResult - err error - ) - query := generated.ListCommonInput{Page: &fixedPage, PageSize: &fixedPageSize, Namespace: input.Namespace, Keyword: input.Keyword} +// ListModelServices lists model services based on input +func ListModelServices(ctx context.Context, c dynamic.Interface, input *generated.ListModelServiceInput) (*generated.PaginatedResult, error) { + // use `UnlimitedPageSize` so we can get all llms and embeddings + query := generated.ListCommonInput{Page: input.Page, PageSize: &common.UnlimitedPageSize, Namespace: input.Namespace, Keyword: input.Keyword} // list all llms - llmList, err = llm.ListLLMs(ctx, c, query) + llmList, err := llm.ListLLMs(ctx, c, query, 
common.WithPageNodeConvertFunc(func(a any) generated.PageNode { + llm, ok := a.(*generated.Llm) + if !ok { + return nil + } + // convert llm to modelserivce + return LLM2ModelService(llm) + })) if err != nil { - klog.Errorf("failed to list llm %s", err) return nil, err } // list all embedders - embedderList, err = embedder.ListEmbedders(ctx, c, query) - if err != nil { - klog.Errorf("failed to list embedder %s", err) - - return nil, err - } - - // list all workers - workerList, err = worker.ListWorkers(ctx, c, generated.ListWorkerInput{Namespace: input.Namespace, Page: &fixedPage, PageSize: &fixedPageSize}) + embedderList, err := embedder.ListEmbedders(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode { + embedder, ok := a.(*generated.Embedder) + if !ok { + return nil + } + // convert embedder to modelserivce + return Embedder2ModelService(embedder) + })) if err != nil { - klog.Errorf("failed to list worker %s", err) return nil, err } - workerResource := make(map[string]*generated.Worker) - for _, item := range workerList.Nodes { - v := item.(*generated.Worker) - workerResource[v.Name] = v - } - - modelServiceList := make([]*generated.ModelService, 0) - intersec := make(map[string]*generated.ModelService) - klog.V(5).Infof("namespace: %s modetype: %s, providerType: %s, apiType: %s", - input.Namespace, input.ModelType, input.ProviderType, input.APIType) - - // The overall idea is to get the llm-filtered list first when not filtering on modelType or when you want to filter on llm. - // Or to filter embedder, first get the list of embedder. - if input.ModelType == "" || input.ModelType == modelTypeAll || input.ModelType == modelTypeLLM { - for idx := range llmList.Nodes { - v := llmList.Nodes[idx].(*generated.Llm) - klog.V(5).Infof("add llm modelservice: llm %s, type: %s, provider: %s. 
filter apiType: %s, filter providers: %s", - v.Name, *v.Type, *v.Provider, input.APIType, input.ProviderType) - if input.APIType != nil && *v.Type != *input.APIType { - continue - } - if input.ProviderType != nil && *v.Provider != *input.ProviderType { - continue - } - - ms := &generated.ModelService{ - Name: v.Name, - Namespace: input.Namespace, - Creator: v.Creator, - Description: v.Description, - Types: new(string), - CreationTimestamp: v.CreationTimestamp, - UpdateTimestamp: v.UpdateTimestamp, - APIType: v.Type, - LlmResource: v, - EmbedderResource: new(generated.Embedder), - Resource: new(generated.Resources), - } - *ms.Types = modelTypeLLM - intersec[v.Name] = ms - - modelServiceList = append(modelServiceList, ms) - klog.V(5).Infof("add llm modelservice: append only llm modelService to list: %s", debugModelService(ms)) - if r, ok := workerResource[v.Name]; ok { - klog.V(5).Infof(" set llm modelservice: %s resource: set modelservice resource: %+v", v.Name, r) - *ms.Resource = r.Resources - ms.ID = r.ID - } + // serviceList keeps all model services with combined + serviceMapList := make(map[string]*generated.ModelService) + for _, node := range append(llmList.Nodes, embedderList.Nodes...) { + ms, _ := node.(*generated.ModelService) + _, ok := serviceMapList[ms.Name] + // if llm & embedder has same name,we treat it as `ModelTypeAll(llm,embedding)` + if ok { + ms.Types = &common.ModelTypeAll } - } else { - for idx := range embedderList.Nodes { - v := embedderList.Nodes[idx].(*generated.Embedder) - klog.V(5).Infof("add embedder modelservice: embedder %s, type: %s, provider: %s. 
filter apiType: %s, filter providers: %s", - v.Name, *v.Type, *v.Provider, input.APIType, input.ProviderType) - - if input.APIType != nil && *v.Type != *input.APIType { - continue - } - if input.ProviderType != nil && *v.Provider != *input.ProviderType { - continue - } - - ms := &generated.ModelService{ - Name: v.Name, - Namespace: input.Namespace, - Creator: v.Creator, - Description: v.Description, - Types: new(string), - CreationTimestamp: v.CreationTimestamp, - UpdateTimestamp: v.UpdateTimestamp, - APIType: v.Type, - EmbedderResource: v, - LlmResource: new(generated.Llm), - Resource: new(generated.Resources), - } - *ms.Types = modelTypeEmbedding - - intersec[v.Name] = ms - modelServiceList = append(modelServiceList, ms) - klog.V(5).Infof("add embedder modelservice: append only embedder modelService to list: %s", debugModelService(ms)) + serviceMapList[ms.Name] = ms + } - if r, ok := workerResource[v.Name]; ok { - klog.V(5).Infof("set embedder modelservice: %s resource: %+v", v.Name, r) - *ms.Resource = r.Resources - ms.ID = r.ID - } - } + // newNodeList + newNodeList := make([]*generated.ModelService, 0, len(serviceMapList)) + for _, node := range serviceMapList { + newNodeList = append(newNodeList, node) } + // sort by creation timestamp + sort.Slice(newNodeList, func(i, j int) bool { + return newNodeList[i].CreationTimestamp.After(*newNodeList[j].CreationTimestamp) + }) + // return ModelService with the actual Page and PageSize page, pageSize := 1, 10 if input.Page != nil && *input.Page > 0 { page = *input.Page @@ -338,213 +241,55 @@ func ListModelServices(ctx context.Context, c dynamic.Interface, input *generate pageSize = *input.PageSize } - // If we are filtering llm or embedding, - // if the list obtained earlier is empty (e.g. llmList), then we don't need to judge the embeddings anymore. 
- if len(modelServiceList) == 0 && input.ModelType != "" && input.ModelType != modelTypeAll { - return &generated.PaginatedResult{ - HasNextPage: false, - Nodes: []generated.PageNode{}, - TotalCount: 0, - }, nil - } + var totalCount int - // After getting the list of llm's, - // we need to determine if any embedder meets the filter condition and if any embedder can be merged into the modeservice of the llm. - // If an embedder meets the filter condition, we need to determine whether to create a new modelservice or merge it into the modelservice of llm. - // The following logic is the same. - if input.ModelType == "" || input.ModelType == modelTypeAll || input.ModelType == modelTypeLLM { - for idx := range embedderList.Nodes { - v := embedderList.Nodes[idx].(*generated.Embedder) - if input.APIType != nil && *v.Type != *input.APIType { - continue - } - if input.ProviderType != nil && *v.Provider != *input.ProviderType { - continue - } - llm, ok := intersec[v.Name] - if !ok && input.ModelType != "" && input.ModelType != modelTypeAll { - continue - } - if !ok || *llm.APIType != *v.Type || *llm.LlmResource.Provider != *v.Provider { - if !ok { - klog.V(5).Infof("match check llm: embedder %s has no matching llm's, add modelservice", v.Name) - } - if ok && (*llm.APIType != *v.Type || *llm.LlmResource.Provider != *v.Provider) { - klog.V(5).Infof("match check llm: embedder %s type: %s, llm apiType: %s, llm provider: %s, embedder provider: %s. 
add modelservice", - v.Name, *v.Type, *llm.APIType, *llm.LlmResource.Provider, *v.Provider) - } - - ms := &generated.ModelService{ - Name: v.Name, - Namespace: input.Namespace, - Creator: v.Creator, - Description: v.Description, - Types: new(string), - CreationTimestamp: v.CreationTimestamp, - UpdateTimestamp: v.UpdateTimestamp, - APIType: v.Type, - LlmResource: new(generated.Llm), - EmbedderResource: v, - Resource: new(generated.Resources), - } - *ms.Types = modelTypeEmbedding - if r, ok := workerResource[v.Name]; ok { - klog.V(5).Infof("match check llm: set modelservice %s resource: %+v", v.Name, r) - *ms.Resource = r.Resources - ms.ID = r.ID - } - - klog.V(5).Infof("match check llm: append only embedder modelService to list: %s", debugModelService(ms)) - modelServiceList = append(modelServiceList, ms) - continue - } + result := make([]generated.PageNode, 0, pageSize) + pageStart := (page - 1) * pageSize - *llm.Types = modelTypeLLM + "," + modelTypeEmbedding - llm.EmbedderResource = v - if llm.CreationTimestamp.After(*v.CreationTimestamp) { - llm.CreationTimestamp = v.CreationTimestamp - } - if llm.UpdateTimestamp.Before(*v.UpdateTimestamp) { - llm.UpdateTimestamp = v.UpdateTimestamp - } - klog.V(5).Infof("match check llm: embedder match llm %s", debugModelService(llm)) - } - } else { - for idx := range llmList.Nodes { - v := llmList.Nodes[idx].(*generated.Llm) - if input.APIType != nil && *v.Type != *input.APIType { - continue - } - if input.ProviderType != nil && *v.Provider != *input.ProviderType { + // index is the actual result length + var index int + for _, service := range newNodeList { + // Add filter conditions here + // 1. filter service types: llm or embedding or both + if input.Types != nil && *input.Types != "" { + if !strings.Contains(*service.Types, *input.Types) { continue } - - embedder, ok := intersec[v.Name] - if !ok && input.ModelType != "" && input.ModelType != modelTypeAll { + } + // 2. 
filter provider type: worker or 3rd_party + if input.ProviderType != nil && *input.ProviderType != "" { + if service.ProviderType == nil || *service.ProviderType != *input.ProviderType { continue } - - if !ok || embedder.Name != v.Name || *embedder.APIType != *v.Type || *embedder.EmbedderResource.Provider != *v.Provider { - if !ok { - klog.V(5).Infof("match check embedder: llm %s has no matching embedder's, add modelservice", v.Name) - } - if ok && (*embedder.APIType != *v.Type || *embedder.EmbedderResource.Provider != *v.Provider) { - klog.V(5).Infof("match check embedder: llm %s type: %s, embedder apiType: %s, embedder provider: %s, llm provider: %s. add modelservice", - v.Name, *v.Type, *embedder.APIType, *embedder.EmbedderResource.Provider, *v.Provider) - } - - ms := &generated.ModelService{ - Name: v.Name, - Namespace: input.Namespace, - Creator: v.Creator, - Description: v.Description, - Types: new(string), - CreationTimestamp: v.CreationTimestamp, - UpdateTimestamp: v.UpdateTimestamp, - APIType: v.Type, - LlmResource: v, - EmbedderResource: new(generated.Embedder), - Resource: new(generated.Resources), - } - *ms.Types = modelTypeLLM - if r, ok := workerResource[v.Name]; ok { - klog.V(5).Infof("match check embedder: set modelservice %s resource: %+v", v.Name, r) - *ms.Resource = r.Resources - ms.ID = v.ID - } - klog.V(5).Infof("match check embedder: append only llm modelService to list: %s", debugModelService(ms)) - modelServiceList = append(modelServiceList, ms) + } + // 3. 
filter api type: openai or zhipuai + if input.APIType != nil && *input.APIType != "" { + if service.APIType == nil || *service.APIType != *input.APIType { continue } - - *embedder.Types = modelTypeLLM + "," + modelTypeEmbedding - embedder.LlmResource = v - if embedder.CreationTimestamp.After(*v.CreationTimestamp) { - embedder.CreationTimestamp = v.CreationTimestamp - } - if embedder.UpdateTimestamp.Before(*v.UpdateTimestamp) { - embedder.UpdateTimestamp = v.UpdateTimestamp - } - klog.V(5).Infof("match check llm: embedder match llm %s", debugModelService(embedder)) } - } - start := (page - 1) * pageSize - end := start + pageSize - total := len(modelServiceList) + // increase totalCount when service meets the filter conditions + totalCount++ - result := pageModelService(start, end, &modelServiceList) - nodes := make([]generated.PageNode, len(result)) - for idx := range result { - nodes[idx] = result[idx] - } - return &generated.PaginatedResult{ - HasNextPage: end < total, - Nodes: nodes, - Page: &page, - PageSize: &pageSize, - TotalCount: total, - }, nil -} - -func GetModelService(ctx context.Context, c dynamic.Interface, name, namespace, apiType string) (*generated.ModelService, error) { - ms := &generated.ModelService{ - ID: new(string), - Name: name, - Namespace: namespace, - Creator: new(string), - Description: new(string), - Types: new(string), - CreationTimestamp: new(time.Time), - UpdateTimestamp: new(time.Time), - APIType: &apiType, - LlmResource: new(generated.Llm), - EmbedderResource: new(generated.Embedder), - Resource: new(generated.Resources), - } - exist := false - if r1, err := llm.ReadLLM(ctx, c, name, namespace); err == nil && *r1.Type == apiType { - exist = true - if r1.CreationTimestamp != nil { - *ms.CreationTimestamp = *r1.CreationTimestamp + // append result + if index >= pageStart && len(result) < pageSize { + result = append(result, service) } - if r1.UpdateTimestamp != nil { - *ms.UpdateTimestamp = *r1.UpdateTimestamp - } - *ms.Types = "llm" 
- if r1.Description != nil { - *ms.Description = *r1.Description - } - *ms.LlmResource = *r1 - } - if r2, err := embedder.ReadEmbedder(ctx, c, name, namespace); err == nil && *r2.Type == apiType { - exist = true - if r2.CreationTimestamp != nil && (ms.CreationTimestamp == nil || r2.CreationTimestamp.Before(*ms.CreationTimestamp)) { - *ms.CreationTimestamp = *r2.CreationTimestamp - } - if r2.UpdateTimestamp != nil && (ms.UpdateTimestamp == nil || r2.UpdateTimestamp.After(*ms.UpdateTimestamp)) { - *ms.UpdateTimestamp = *r2.UpdateTimestamp - } - if *ms.Description == "" && r2.Description != nil { - *ms.Description = *r2.Description - } - if *ms.Types == modelTypeLLM { - *ms.Types += "," + modelTypeEmbedding - } else { - *ms.Types = modelTypeEmbedding - } - *ms.EmbedderResource = *r2 + index++ } - if r3, err := worker.ReadWorker(ctx, c, name, namespace); err == nil { - *ms.Resource = r3.Resources - *ms.ID = *r3.ID - } - if !exist { - return nil, fmt.Errorf("not found modelService %s", name) + end := page * pageSize + if end > totalCount { + end = totalCount } - return ms, nil + + return &generated.PaginatedResult{ + TotalCount: totalCount, + HasNextPage: end < totalCount, + Nodes: result, + }, nil } var ( diff --git a/apiserver/pkg/modelservice/sort.go b/apiserver/pkg/modelservice/sort.go deleted file mode 100644 index 8f08d5340..000000000 --- a/apiserver/pkg/modelservice/sort.go +++ /dev/null @@ -1,72 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -package modelservice - -import ( - "container/heap" - - "github.com/kubeagi/arcadia/apiserver/graph/generated" -) - -// Because here the merger of llm and embedders two kinds of data, -// so the amount of data than one kind of quite a bit more, -// so the direct sort may not be an optimal solution, -// we can use the heap, and start, end the two parameters do not arrange all the data, and then find the data to be paged. -type ModelServiceList []*generated.ModelService - -func (m *ModelServiceList) Len() int { - return len(*m) -} - -func (m *ModelServiceList) Less(i, j int) bool { - a := (*m)[i] - b := (*m)[j] - if a.CreationTimestamp.Equal(*b.CreationTimestamp) { - return a.Name < b.Name - } - return a.CreationTimestamp.After(*b.CreationTimestamp) -} - -func (m *ModelServiceList) Swap(i, j int) { - (*m)[i], (*m)[j] = (*m)[j], (*m)[i] -} - -func (m *ModelServiceList) Push(x any) { - *m = append(*m, x.(*generated.ModelService)) -} - -func (m *ModelServiceList) Pop() any { - old := *m - l := len(old) - x := old[l-1] - *m = old[:l-1] - return x -} - -func pageModelService(start, end int, list *[]*generated.ModelService) []*generated.ModelService { - l := ModelServiceList(*list) - heap.Init(&l) - - r := make([]*generated.ModelService, 0) - for cur := 0; cur < end && l.Len() > 0; cur++ { - top := heap.Pop(&l) - if cur < start { - continue - } - r = append(r, top.(*generated.ModelService)) - } - return r -} diff --git a/apiserver/pkg/modelservice/sort_test.go b/apiserver/pkg/modelservice/sort_test.go deleted file mode 100644 index de9657d3e..000000000 --- a/apiserver/pkg/modelservice/sort_test.go +++ /dev/null @@ -1,76 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package modelservice - -import ( - "testing" - "time" - - "github.com/kubeagi/arcadia/apiserver/graph/generated" -) - -type input struct { - start, end int - exp []string -} - -func initData() []*generated.ModelService { - seconds := []int{ - 0, -2, 4, -6, 8, -10, 12, -14, 16, -18, 20, -22, 24, 26, 28, 30, - } - // sorted order: 30,28,26,24,33, 66, 5, 4, 7, 3, 1, 2, 6, 10, 55, 88 - source := []*generated.ModelService{ - {Name: "3"}, {Name: "1"}, {Name: "7"}, - {Name: "2"}, {Name: "4"}, {Name: "6"}, - {Name: "5"}, {Name: "10"}, {Name: "66"}, - {Name: "55"}, {Name: "33"}, {Name: "88"}, - {Name: "24"}, {Name: "26"}, {Name: "28"}, - {Name: "30"}, - } - now := time.Now() - for idx := range source { - tmp := now.Add(time.Second * time.Duration(seconds[idx])) - source[idx].CreationTimestamp = &tmp - } - return source -} - -func TestPagedModelService(t *testing.T) { - for _, tc := range []input{ - {0, 3, []string{"30", "28", "26"}}, - {3, 6, []string{"24", "33", "66"}}, - {6, 9, []string{"5", "4", "7"}}, - {9, 12, []string{"3", "1", "2"}}, - {12, 15, []string{"6", "10", "55"}}, - {15, 18, []string{"88"}}, - {0, 7, []string{"30", "28", "26", "24", "33", "66", "5"}}, - {7, 14, []string{"4", "7", "3", "1", "2", "6", "10"}}, - {14, 18, []string{"55", "88"}}, - {0, 10, []string{"30", "28", "26", "24", "33", "66", "5", "4", "7", "3"}}, - {10, 20, []string{"1", "2", "6", "10", "55", "88"}}, - } { - source := initData() - r := pageModelService(tc.start, tc.end, &source) - if len(tc.exp) != len(r) { - t.Fatalf("expect %d items, got %d", len(tc.exp), len(r)) - } - for idx := range r { 
- if r[idx].Name != tc.exp[idx] { - t.Fatalf("[%d - %d] expects the i-th to be %s but actually is %s", tc.start, tc.end, tc.exp[idx], r[idx].Name) - } - } - } -} diff --git a/apiserver/pkg/worker/worker.go b/apiserver/pkg/worker/worker.go index 156f1aa69..6a64041e6 100644 --- a/apiserver/pkg/worker/worker.go +++ b/apiserver/pkg/worker/worker.go @@ -45,7 +45,8 @@ const ( NvidiaGPU = "nvidia.com/gpu" ) -func worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) *generated.Worker { +// Worker2model convert unstructured `CR Worker` to graphql model +func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) *generated.Worker { worker := &v1alpha1.Worker{} if err := utils.UnstructuredToStructured(obj, worker); err != nil { return &generated.Worker{} @@ -61,6 +62,7 @@ func worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un // Unknown,Pending ,Running ,Error status := common.GetObjStatus(worker) + message := condition.Message // replicas var replicas string @@ -97,6 +99,7 @@ func worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un Description: &worker.Spec.Description, Type: &workerType, Status: &status, + Message: &message, CreationTimestamp: &creationtimestamp, UpdateTimestamp: &updateTime, Replicas: &replicas, @@ -194,7 +197,7 @@ func CreateWorker(ctx context.Context, c dynamic.Interface, input generated.Crea if err != nil { return nil, err } - return worker2model(ctx, c, obj), nil + return Worker2model(ctx, c, obj), nil } func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.UpdateWorkerInput) (*generated.Worker, error) { @@ -267,7 +270,7 @@ func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.Upd return nil, err } - return worker2model(ctx, c, updatedObject), nil + return Worker2model(ctx, c, updatedObject), nil } func DeleteWorkers(ctx context.Context, c dynamic.Interface, input 
*generated.DeleteCommonInput) (*string, error) { @@ -301,7 +304,12 @@ func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.De return nil, nil } -func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput) (*generated.PaginatedResult, error) { +func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) { + opts := common.DefaultListOptions() + for _, optFunc := range listOpts { + optFunc(opts) + } + keyword, modelTypes, labelSelector, fieldSelector := "", "", "", "" page, pageSize := 1, 10 if input.Keyword != nil { @@ -338,6 +346,12 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW totalCount := len(us.Items) + // if pageSize is -1 which means unlimited pagesize,return all + if pageSize == common.UnlimitedPageSize { + page = 1 + pageSize = totalCount + } + result := make([]generated.PageNode, 0, pageSize) pageStart := (page - 1) * pageSize for index, u := range us.Items { @@ -345,7 +359,7 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW if index < pageStart { continue } - m := worker2model(ctx, c, &u) + m := Worker2model(ctx, c, &u) // filter based on `keyword` if keyword != "" { if !strings.Contains(m.Name, keyword) && !strings.Contains(*m.DisplayName, keyword) { @@ -358,7 +372,7 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW } } - result = append(result, m) + result = append(result, opts.ConvertFunc(m)) // break if page size matches if len(result) == pageSize { @@ -388,5 +402,5 @@ func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string if err != nil { return nil, err } - return worker2model(ctx, c, u), nil + return Worker2model(ctx, c, u), nil }