Skip to content

Commit

Permalink
fix: too long to call listModelServices
Browse files Browse the repository at this point in the history
  • Loading branch information
0xff-dev committed Jan 19, 2024
1 parent 360066b commit 058cf46
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 100 deletions.
2 changes: 2 additions & 0 deletions api/base/v1alpha1/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ const (
ProviderTypeUnknown ProviderType = "unknown"
ProviderType3rdParty ProviderType = "3rd_party"
ProviderTypeWorker ProviderType = "worker"

ProviderLabel = Group + "/provider"
)

// Provider defines how to prvoide the service
Expand Down
2 changes: 1 addition & 1 deletion apiserver/graph/impl/worker.resolvers.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions apiserver/pkg/common/common_list_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (
func ListReources(list *unstructured.UnstructuredList, page, pageSize int, converter ResourceConverter, options ...ResourceFilter) (*generated.PaginatedResult, error) {
index, optIndex := 0, 0
for i := range list.Items {
optIndex = 0
for ; optIndex < len(options); optIndex++ {
for optIndex = 0; optIndex < len(options); optIndex++ {
if !options[optIndex](&list.Items[i]) {
break
}
Expand Down
26 changes: 26 additions & 0 deletions apiserver/pkg/modelservice/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package modelservice
import (
"github.com/kubeagi/arcadia/apiserver/graph/generated"
"github.com/kubeagi/arcadia/apiserver/pkg/common"
"github.com/kubeagi/arcadia/pkg/llms"
)

// Embedder2ModelService convert unstructured `CR Embedder` to graphql model `ModelService`
Expand Down Expand Up @@ -80,3 +81,28 @@ func LLM2ModelService(llm *generated.Llm) *generated.ModelService {
}
return ms
}

func Worker2ModelService(worker *generated.Worker) *generated.ModelService {
ms := &generated.ModelService{
ID: worker.ID,
Name: worker.Name,
Namespace: worker.Namespace,
CreationTimestamp: worker.CreationTimestamp,
UpdateTimestamp: worker.UpdateTimestamp,
Creator: worker.Creator,
DisplayName: worker.DisplayName,
Description: worker.Description,
ProviderType: new(string),
Types: new(string),
APIType: new(string),
EmbeddingModels: []string{*worker.ID},
LlmModels: []string{*worker.ID},
Status: worker.Status,
Message: worker.Message,
}

*ms.ProviderType = "worker"
*ms.Types = common.ModelTypeAll
*ms.APIType = string(llms.OpenAI)
return ms
}
166 changes: 81 additions & 85 deletions apiserver/pkg/modelservice/modelservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ import (
"github.com/tmc/langchaingo/llms"
"k8s.io/client-go/dynamic"

"github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/apiserver/graph/generated"
"github.com/kubeagi/arcadia/apiserver/pkg/common"
"github.com/kubeagi/arcadia/apiserver/pkg/embedder"
"github.com/kubeagi/arcadia/apiserver/pkg/llm"
"github.com/kubeagi/arcadia/apiserver/pkg/worker"
llmspkg "github.com/kubeagi/arcadia/pkg/llms"
"github.com/kubeagi/arcadia/pkg/llms/openai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
)
Expand Down Expand Up @@ -248,57 +251,91 @@ func ReadModelService(ctx context.Context, c dynamic.Interface, name string, nam
// ListModelServices based on input
func ListModelServices(ctx context.Context, c dynamic.Interface, input *generated.ListModelServiceInput) (*generated.PaginatedResult, error) {
// use `UnlimitedPageSize` so we can get all llms and embeddings
query := generated.ListCommonInput{Page: input.Page, PageSize: &common.UnlimitedPageSize, Namespace: input.Namespace, Keyword: input.Keyword}
notWorkerSelector := fmt.Sprintf("%s=%s", v1alpha1.ProviderLabel, v1alpha1.ProviderType3rdParty)

// list all llms
llmList, err := llm.ListLLMs(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode {
llm, ok := a.(*generated.Llm)
if !ok {
return nil
query := generated.ListCommonInput{
Page: input.Page,
PageSize: &common.UnlimitedPageSize,
Namespace: input.Namespace,
Keyword: input.Keyword,
LabelSelector: &notWorkerSelector,
}

newNodeList := make([]generated.PageNode, 0)

if input.ProviderType == nil || *input.ProviderType == string(v1alpha1.ProviderType3rdParty) {
exist := make(map[string]int)
if input.Types == nil || (*input.Types == common.ModelTypeAll || *input.Types == common.ModelTypeLLM) {
llmList, err := llm.ListLLMs(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode {
llm, ok := a.(*generated.Llm)
if !ok {
return nil
}
return LLM2ModelService(llm)
}))
if err != nil {
return nil, err
}
for _, n := range llmList.Nodes {
tmp := n.(*generated.ModelService)
if input.APIType != nil && *input.APIType != *tmp.APIType {
continue
}
newNodeList = append(newNodeList, tmp)
exist[tmp.Name] = len(newNodeList) - 1
}
}
// convert llm to modelserivce
return LLM2ModelService(llm)
}))
if err != nil {
return nil, err
}

// list all embedders
embedderList, err := embedder.ListEmbedders(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode {
embedder, ok := a.(*generated.Embedder)
if !ok {
return nil
if input.Types == nil || (*input.Types == common.ModelTypeAll || *input.Types == common.ModelTypeEmbedding) {
embedderList, err := embedder.ListEmbedders(ctx, c, query, common.WithPageNodeConvertFunc(func(a any) generated.PageNode {
embedder, ok := a.(*generated.Embedder)
if !ok {
return nil
}
return Embedder2ModelService(embedder)
}))
if err != nil {
return nil, err
}
for _, n := range embedderList.Nodes {
tmp := n.(*generated.ModelService)
if input.APIType != nil && *input.APIType != *tmp.APIType {
continue
}
if idx, ok := exist[tmp.Name]; ok {
t := newNodeList[idx].(*generated.ModelService)
t.Types = &common.ModelTypeAll
t.EmbeddingModels = tmp.EmbeddingModels
continue
}
newNodeList = append(newNodeList, tmp)
}
}
// convert embedder to modelserivce
return Embedder2ModelService(embedder)
}))
if err != nil {
return nil, err
}

// serviceList keeps all model services with combined
serviceMapList := make(map[string]*generated.ModelService)
for _, node := range append(llmList.Nodes, embedderList.Nodes...) {
ms, _ := node.(*generated.ModelService)
curr, ok := serviceMapList[ms.Name]
// if llm & embedder has same name,we treat it as `ModelTypeAll(llm,embedding)`
if ok {
ms.Types = &common.ModelTypeAll
// combine models provided by this model service
ms.LlmModels = append(ms.LlmModels, curr.LlmModels...)
ms.EmbeddingModels = append(ms.EmbeddingModels, curr.EmbeddingModels...)
if input.ProviderType == nil || *input.ProviderType == string(v1alpha1.ProviderTypeWorker) {
if input.APIType == nil || *input.APIType == string(llmspkg.OpenAI) {
workerQuery := generated.ListWorkerInput{
Page: query.Page,
PageSize: &common.UnlimitedPageSize,
Namespace: input.Namespace,
Keyword: input.Keyword,
}
workerList, err := worker.ListWorkers(ctx, c, workerQuery, false)
if err != nil {
return nil, err
}
for _, n := range workerList.Nodes {
tmp := n.(*generated.Worker)
newNodeList = append(newNodeList, Worker2ModelService(tmp))
}
}
serviceMapList[ms.Name] = ms
}

// newNodeList
newNodeList := make([]*generated.ModelService, 0, len(serviceMapList))
for _, node := range serviceMapList {
newNodeList = append(newNodeList, node)
}
// sort by creation timestamp
sort.Slice(newNodeList, func(i, j int) bool {
return newNodeList[i].CreationTimestamp.After(*newNodeList[j].CreationTimestamp)
a := newNodeList[i].(*generated.ModelService)
b := newNodeList[j].(*generated.ModelService)
return a.CreationTimestamp.After(*b.CreationTimestamp)
})

// return ModelService with the actual Page and PageSize
Expand All @@ -310,54 +347,13 @@ func ListModelServices(ctx context.Context, c dynamic.Interface, input *generate
pageSize = *input.PageSize
}

var totalCount int

result := make([]generated.PageNode, 0, pageSize)
pageStart := (page - 1) * pageSize

// index is the actual result length
var index int
for _, service := range newNodeList {
// Add filter conditions here
// 1. filter service types: llm or embedding or both
if input.Types != nil && *input.Types != "" {
if !strings.Contains(*service.Types, *input.Types) {
continue
}
}
// 2. filter provider type: worker or 3rd_party
if input.ProviderType != nil && *input.ProviderType != "" {
if service.ProviderType == nil || *service.ProviderType != *input.ProviderType {
continue
}
}
// 3. filter api type: openai or zhipuai
if input.APIType != nil && *input.APIType != "" {
if service.APIType == nil || *service.APIType != *input.APIType {
continue
}
}

// increase totalCount when service meets the filter conditions
totalCount++

// append result
if index >= pageStart && len(result) < pageSize {
result = append(result, service)
}

index++
}

end := page * pageSize
if end > totalCount {
end = totalCount
}
totalCount := len(newNodeList)
start, end := common.PagePosition(page, pageSize, totalCount)

return &generated.PaginatedResult{
TotalCount: totalCount,
HasNextPage: end < totalCount,
Nodes: result,
Nodes: newNodeList[start:end],
}, nil
}

Expand Down
53 changes: 41 additions & 12 deletions apiserver/pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ const (
NvidiaGPU = "nvidia.com/gpu"
)

func worker2modelConverter(ctx context.Context, c dynamic.Interface) func(*unstructured.Unstructured) (generated.PageNode, error) {
func worker2modelConverter(ctx context.Context, c dynamic.Interface, showModel bool) func(*unstructured.Unstructured) (generated.PageNode, error) {
return func(u *unstructured.Unstructured) (generated.PageNode, error) {
return Worker2model(ctx, c, u)
return Worker2model(ctx, c, u, showModel)
}
}

// Worker2model convert unstructured `CR Worker` to graphql model
func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) (*generated.Worker, error) {
func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured, showModel bool) (*generated.Worker, error) {
worker := &v1alpha1.Worker{}
if err := utils.UnstructuredToStructured(obj, worker); err != nil {
return nil, err
Expand Down Expand Up @@ -107,8 +107,6 @@ func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un

workerType := string(worker.Type())

api, _ := common.GetAPIServer(ctx, c, true)

// wrap Worker
w := generated.Worker{
ID: &id,
Expand All @@ -129,11 +127,11 @@ func Worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un
MatchExpressions: matchExpressions,
AdditionalEnvs: additionalEnvs,
ModelTypes: "unknown",
API: &api,
API: new(string),
}

// read worker's models
if worker.Spec.Model != nil {
if worker.Spec.Model != nil && showModel {
typedModel := worker.Model()
model, err := gqlmodel.ReadModel(ctx, c, typedModel.Name, *typedModel.Namespace)
if err != nil {
Expand Down Expand Up @@ -247,7 +245,15 @@ func CreateWorker(ctx context.Context, c dynamic.Interface, input generated.Crea
if err != nil {
return nil, err
}
return Worker2model(ctx, c, obj)

api, _ := common.GetAPIServer(ctx, c, true)

w, err := Worker2model(ctx, c, obj, true)
if err != nil {
return nil, err
}
*w.API = api
return w, nil
}

func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.UpdateWorkerInput) (*generated.Worker, error) {
Expand Down Expand Up @@ -344,8 +350,14 @@ func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.Upd
if err != nil {
return nil, err
}
api, _ := common.GetAPIServer(ctx, c, true)

return Worker2model(ctx, c, updatedObject)
w, err := Worker2model(ctx, c, updatedObject, true)
if err != nil {
return nil, err
}
*w.API = api
return w, nil
}

func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.DeleteCommonInput) (*string, error) {
Expand Down Expand Up @@ -379,7 +391,7 @@ func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.De
return nil, nil
}

func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) {
func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListWorkerInput, showWorkerModel bool, listOpts ...common.ListOptionsFunc) (*generated.PaginatedResult, error) {
opts := common.DefaultListOptions()
for _, optFunc := range listOpts {
optFunc(opts)
Expand Down Expand Up @@ -415,7 +427,17 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW
if err != nil {
return nil, err
}
return common.ListReources(us, page, pageSize, worker2modelConverter(ctx, c), filter...)
list, err := common.ListReources(us, page, pageSize, worker2modelConverter(ctx, c, showWorkerModel), filter...)
if err != nil {
return nil, err
}
api, _ := common.GetAPIServer(ctx, c, true)

for i := range list.Nodes {
tmp := list.Nodes[i].(*generated.Worker)
*tmp.API = api
}
return list, nil
}

func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string) (*generated.Worker, error) {
Expand All @@ -428,5 +450,12 @@ func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string
if err != nil {
return nil, err
}
return Worker2model(ctx, c, u)
api, _ := common.GetAPIServer(ctx, c, true)

w, err := Worker2model(ctx, c, u, true)
if err != nil {
return nil, err
}
*w.API = api
return w, nil
}
Loading

0 comments on commit 058cf46

Please sign in to comment.