From 058cf4633e627d441374d9d7b0e63fedff67b2e1 Mon Sep 17 00:00:00 2001 From: 0xff-dev Date: Fri, 19 Jan 2024 14:48:32 +0800 Subject: [PATCH] fix: too long to call listModelServices --- api/base/v1alpha1/common.go | 2 + apiserver/graph/impl/worker.resolvers.go | 2 +- apiserver/pkg/common/common_list_resource.go | 3 +- apiserver/pkg/modelservice/helper.go | 26 +++ apiserver/pkg/modelservice/modelservice.go | 166 +++++++++---------- apiserver/pkg/worker/worker.go | 53 ++++-- controllers/base/embedder_controller.go | 14 ++ controllers/base/llm_controller.go | 12 ++ 8 files changed, 178 insertions(+), 100 deletions(-) diff --git a/api/base/v1alpha1/common.go b/api/base/v1alpha1/common.go index 30f8c9182..4ed2b9a2b 100644 --- a/api/base/v1alpha1/common.go +++ b/api/base/v1alpha1/common.go @@ -49,6 +49,8 @@ const ( ProviderTypeUnknown ProviderType = "unknown" ProviderType3rdParty ProviderType = "3rd_party" ProviderTypeWorker ProviderType = "worker" + + ProviderLabel = Group + "/provider" ) // Provider defines how to prvoide the service diff --git a/apiserver/graph/impl/worker.resolvers.go b/apiserver/graph/impl/worker.resolvers.go index 582643dc4..d17d8fc41 100644 --- a/apiserver/graph/impl/worker.resolvers.go +++ b/apiserver/graph/impl/worker.resolvers.go @@ -66,7 +66,7 @@ func (r *workerQueryResolver) ListWorkers(ctx context.Context, obj *generated.Wo return nil, err } - return md.ListWorkers(ctx, c, input) + return md.ListWorkers(ctx, c, input, true) } // WorkerMutation returns generated.WorkerMutationResolver implementation. diff --git a/apiserver/pkg/common/common_list_resource.go b/apiserver/pkg/common/common_list_resource.go index 37d503543..8e9c21e0e 100644 --- a/apiserver/pkg/common/common_list_resource.go +++ b/apiserver/pkg/common/common_list_resource.go @@ -28,8 +28,7 @@ import ( func ListReources(list *unstructured.UnstructuredList, page, pageSize int, converter ResourceConverter, options ...ResourceFilter) (*generated.PaginatedResult, error) { index, optIndex := 0, 0 for i := range list.Items { - optIndex = 0 - for ; optIndex < len(options); optIndex++ { + for optIndex = 0; optIndex < len(options); optIndex++ { if !options[optIndex](&list.Items[i]) { break } diff --git a/apiserver/pkg/modelservice/helper.go b/apiserver/pkg/modelservice/helper.go index 0d44f13a3..41e28ff75 100644 --- a/apiserver/pkg/modelservice/helper.go +++ b/apiserver/pkg/modelservice/helper.go @@ -19,6 +19,7 @@ package modelservice import ( "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/common" + "github.com/kubeagi/arcadia/pkg/llms" ) // Embedder2ModelService convert unstructured `CR Embedder` to graphql model `ModelService` @@ -80,3 +81,28 @@ func LLM2ModelService(llm *generated.Llm) *generated.ModelService { } return ms } + +func Worker2ModelService(worker *generated.Worker) *generated.ModelService { + ms := &generated.ModelService{ + ID: worker.ID, + Name: worker.Name, + Namespace: worker.Namespace, + CreationTimestamp: worker.CreationTimestamp, + UpdateTimestamp: worker.UpdateTimestamp, + Creator: worker.Creator, + DisplayName: worker.DisplayName, + Description: worker.Description, + ProviderType: new(string), + Types: new(string), + APIType: new(string), + EmbeddingModels: []string{*worker.ID}, + LlmModels: []string{*worker.ID}, + Status: worker.Status, + Message: worker.Message, + } + + *ms.ProviderType = "worker" + *ms.Types = common.ModelTypeAll + *ms.APIType = string(llms.OpenAI) + return ms +} diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index 0a170ce11..84e402f61 100644 --- a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -27,10 +27,13 @@ import ( "github.com/tmc/langchaingo/llms" "k8s.io/client-go/dynamic" + "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/apiserver/graph/generated" "github.com/kubeagi/arcadia/apiserver/pkg/common" "github.com/kubeagi/arcadia/apiserver/pkg/embedder" "github.com/kubeagi/arcadia/apiserver/pkg/llm" + "github.com/kubeagi/arcadia/apiserver/pkg/worker" + llmspkg "github.com/kubeagi/arcadia/pkg/llms" "github.com/kubeagi/arcadia/pkg/llms/openai" "github.com/kubeagi/arcadia/pkg/llms/zhipuai" ) @@ -248,57 +251,91 @@ func ReadModelService(ctx context.Context, c dynamic.Interface, name string, nam // ListModelServices based on input func ListModelServices(ctx context.Context, c dynamic.Interface, input *generated.ListModelServiceInput) (*generated.PaginatedResult, error) { // use `UnlimitedPageSize` so we can get all llms and embeddings - query := generated.ListCommonInput{Page: input.Page, PageSize: &common.UnlimitedPageSize, Namespace: input.Namespace, Keyword: input.Keyword} + notWorkerSelector := fmt.Sprintf("%s=%s", v1alpha1.ProviderLabel, v1alpha1.ProviderType3rdParty) - // list all llms - llmList, err := llm.ListLLMs(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode { - llm, ok := a.(*generated.Llm) - if !ok { - return nil + query := generated.ListCommonInput{ + Page: input.Page, + PageSize: &common.UnlimitedPageSize, + Namespace: input.Namespace, + Keyword: input.Keyword, + LabelSelector: ¬WorkerSelector, + } + + newNodeList := make([]generated.PageNode, 0) + + if input.ProviderType == nil || *input.ProviderType == string(v1alpha1.ProviderType3rdParty) { + exist := make(map[string]int) + if input.Types == nil || (*input.Types == common.ModelTypeAll || *input.Types == common.ModelTypeLLM) { + llmList, err := llm.ListLLMs(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode { + llm, ok := a.(*generated.Llm) + if !ok { + return nil + } + return LLM2ModelService(llm) + })) + if err != nil { + return nil, err + } + for _, n := range llmList.Nodes { + tmp := n.(*generated.ModelService) + if input.APIType != nil && *input.APIType != *tmp.APIType { + continue + } + newNodeList = append(newNodeList, tmp) + exist[tmp.Name] = len(newNodeList) - 1 + } } - // convert llm to modelserivce - return LLM2ModelService(llm) - })) - if err != nil { - return nil, err - } - // list all embedders - embedderList, err := embedder.ListEmbedders(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode { - embedder, ok := a.(*generated.Embedder) - if !ok { - return nil + if input.Types == nil || (*input.Types == common.ModelTypeAll || *input.Types == common.ModelTypeEmbedding) { + embedderList, err := embedder.ListEmbedders(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode { + embedder, ok := a.(*generated.Embedder) + if !ok { + return nil + } + return Embedder2ModelService(embedder) + })) + if err != nil { + return nil, err + } + for _, n := range embedderList.Nodes { + tmp := n.(*generated.ModelService) + if input.APIType != nil && *input.APIType != *tmp.APIType { + continue + } + if idx, ok := exist[tmp.Name]; ok { + t := newNodeList[idx].(*generated.ModelService) + t.Types = &common.ModelTypeAll + t.EmbeddingModels = tmp.EmbeddingModels + continue + } + newNodeList = append(newNodeList, tmp) + } } - // convert embedder to modelserivce - return Embedder2ModelService(embedder) - })) - if err != nil { - return nil, err } - // serviceList keeps all model services with combined - serviceMapList := make(map[string]*generated.ModelService) - for _, node := range append(llmList.Nodes, embedderList.Nodes...) { - ms, _ := node.(*generated.ModelService) - curr, ok := serviceMapList[ms.Name] - // if llm & embedder has same name,we treat it as `ModelTypeAll(llm,embedding)` - if ok { - ms.Types = &common.ModelTypeAll - // combine models provided by this model service - ms.LlmModels = append(ms.LlmModels, curr.LlmModels...) - ms.EmbeddingModels = append(ms.EmbeddingModels, curr.EmbeddingModels...) + if input.ProviderType == nil || *input.ProviderType == string(v1alpha1.ProviderTypeWorker) { + if input.APIType == nil || *input.APIType == string(llmspkg.OpenAI) { + workerQuery := generated.ListWorkerInput{ + Page: query.Page, + PageSize: &common.UnlimitedPageSize, + Namespace: input.Namespace, + Keyword: input.Keyword, + } + workerList, err := worker.ListWorkers(ctx, c, workerQuery, false) + if err != nil { + return nil, err + } + for _, n := range workerList.Nodes { + tmp := n.(*generated.Worker) + newNodeList = append(newNodeList, Worker2ModelService(tmp)) + } } - serviceMapList[ms.Name] = ms } - // newNodeList - newNodeList := make([]*generated.ModelService, 0, len(serviceMapList)) - for _, node := range serviceMapList { - newNodeList = append(newNodeList, node) - } - // sort by creation timestamp sort.Slice(newNodeList, func(i, j int) bool { - return newNodeList[i].CreationTimestamp.After(*newNodeList[j].CreationTimestamp) + a := newNodeList[i].(*generated.ModelService) + b := newNodeList[j].(*generated.ModelService) + return a.CreationTimestamp.After(*b.CreationTimestamp) }) // return ModelService with the actual Page and PageSize @@ -310,54 +347,13 @@ func ListModelServices(ctx context.Context, c dynamic.Interface, input *generate pageSize = *input.PageSize } - var totalCount int - - result := make([]generated.PageNode, 0, pageSize) - pageStart := (page - 1) * pageSize - - // index is the actual result length - var index int - for _, service := range newNodeList { - // Add filter conditions here - // 1. filter service types: llm or embedding or both - if input.Types != nil && *input.Types != "" { - if !strings.Contains(*service.Types, *input.Types) { - continue - } - } - // 2. filter provider type: worker or 3rd_party - if input.ProviderType != nil && *input.ProviderType != "" { - if service.ProviderType == nil || *service.ProviderType != *input.ProviderType { - continue - } - } - // 3. filter api type: openai or zhipuai - if input.APIType != nil && *input.APIType != "" { - if service.APIType == nil || *service.APIType != *input.APIType { - continue - } - } - - // increase totalCount when service meets the filter conditions - totalCount++ - - // append result - if index >= pageStart && len(result) < pageSize { - result = append(result, service) - } - - index++ - } - - end := page * pageSize - if end > totalCount { - end = totalCount - } + totalCount := len(newNodeList) + start, end := common.PagePosition(page, pageSize, totalCount) return &generated.PaginatedResult{ TotalCount: totalCount, HasNextPage: end < totalCount, - Nodes: result, + Nodes: newNodeList[start:end], }, nil } diff --git a/apiserver/pkg/worker/worker.go b/apiserver/pkg/worker/worker.go index b72a37605..52ed88a1f 100644 --- a/apiserver/pkg/worker/worker.go +++ b/apiserver/pkg/worker/worker.go @@ -43,14 +43,14 @@ const ( NvidiaGPU = "nvidia.com/gpu" ) -func worker2modelConverter(ctx context.Context, c dynamic.Interface) func(*unstructured.Unstructured) (generated.PageNode, error) { +func worker2modelConverter(ctx context.Context, c dynamic.Interface, showModel bool) func(*unstructured.Unstructured) (generated.PageNode, error) { return func(u *unstructured.Unstructured) (generated.PageNode, error) { - return Worker2model(ctx, c, u) + return Worker2model(ctx, c, u, showModel) } } // Worker2model convert unstructured `CR Worker` to graphql model -func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) (*generated.Worker, error) { +func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured, showModel bool) (*generated.Worker, error) { worker := &v1alpha1.Worker{} if err := utils.UnstructuredToStructured(obj, worker); err != nil { return nil, err @@ -107,8 +107,6 @@ func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un workerType := string(worker.Type()) - api, _ := common.GetAPIServer(ctx, c, true) - // wrap Worker w := generated.Worker{ ID: &id, @@ -129,11 +127,11 @@ func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un MatchExpressions: matchExpressions, AdditionalEnvs: additionalEnvs, ModelTypes: "unknown", - API: &api, + API: new(string), } // read worker's models - if worker.Spec.Model != nil { + if worker.Spec.Model != nil && showModel { typedModel := worker.Model() model, err := gqlmodel.ReadModel(ctx, c, typedModel.Name, *typedModel.Namespace) if err != nil { @@ -247,7 +245,15 @@ func CreateWorker(ctx context.Context, c dynamic.Interface, input generated.Crea if err != nil { return nil, err } - return Worker2model(ctx, c, obj) + + api, _ := common.GetAPIServer(ctx, c, true) + + w, err := Worker2model(ctx, c, obj, true) + if err != nil { + return nil, err + } + *w.API = api + return w, nil } func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.UpdateWorkerInput) (*generated.Worker, error) { @@ -344,8 +350,14 @@ func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.Upd if err != nil { return nil, err } + api, _ := common.GetAPIServer(ctx, c, true) - return Worker2model(ctx, c, updatedObject) + w, err := Worker2model(ctx, c, updatedObject, true) + if err != nil { + return nil, err + } + *w.API = api + return w, nil } func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.DeleteCommonInput) (*string, error) { @@ -379,7 +391,7 @@ func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.De return nil, nil } -func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) { +func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput, showWorkerModel bool, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) { opts := common.DefaultListOptions() for _, optFunc := range listOpts { optFunc(opts) @@ -415,7 +427,17 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW if err != nil { return nil, err } - return common.ListReources(us, page, pageSize, worker2modelConverter(ctx, c), filter...) + list, err := common.ListReources(us, page, pageSize, worker2modelConverter(ctx, c, showWorkerModel), filter...) + if err != nil { + return nil, err + } + api, _ := common.GetAPIServer(ctx, c, true) + + for i := range list.Nodes { + tmp := list.Nodes[i].(*generated.Worker) + *tmp.API = api + } + return list, nil } func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string) (*generated.Worker, error) { @@ -428,5 +450,12 @@ func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string if err != nil { return nil, err } - return Worker2model(ctx, c, u) + api, _ := common.GetAPIServer(ctx, c, true) + + w, err := Worker2model(ctx, c, u, true) + if err != nil { + return nil, err + } + *w.API = api + return w, nil } diff --git a/controllers/base/embedder_controller.go b/controllers/base/embedder_controller.go index 26c5355c8..580871128 100644 --- a/controllers/base/embedder_controller.go +++ b/controllers/base/embedder_controller.go @@ -111,6 +111,20 @@ func (r *EmbedderReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c logger.Info("Remove Embedder done") return ctrl.Result{}, nil } + if instance.Labels == nil { + instance.Labels = make(map[string]string) + } + providerType := instance.Spec.Provider.GetType() + + if _type, ok := instance.Labels[arcadiav1alpha1.ProviderLabel]; !ok || _type != string(providerType) { + instance.Labels[arcadiav1alpha1.ProviderLabel] = string(providerType) + err := r.Client.Update(ctx, instance) + if err != nil { + logger.Error(err, "failed to update embedder labels", "providerType", providerType) + } + return ctrl.Result{}, err + } + if err := r.CheckEmbedder(ctx, logger, instance); err != nil { return ctrl.Result{RequeueAfter: waitMedium}, err } diff --git a/controllers/base/llm_controller.go b/controllers/base/llm_controller.go index 34452ac0c..5c3ace8d1 100644 --- a/controllers/base/llm_controller.go +++ b/controllers/base/llm_controller.go @@ -108,6 +108,18 @@ func (r *LLMReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R logger.Info("Remove LLM done") return ctrl.Result{}, nil } + if instance.Labels == nil { + instance.Labels = make(map[string]string) + } + providerType := instance.Spec.Provider.GetType() + if _type, ok := instance.Labels[arcadiav1alpha1.ProviderLabel]; !ok || _type != string(providerType) { + instance.Labels[arcadiav1alpha1.ProviderLabel] = string(providerType) + err := r.Client.Update(ctx, instance) + if err != nil { + logger.Error(err, "failed to update llm lables", "providerType", providerType) + } + return ctrl.Result{}, err + } err := r.CheckLLM(ctx, logger, instance) if err != nil {