diff --git a/Dockerfile b/Dockerfile index 3d2265294..3279488c5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,6 @@ COPY go.mod go.mod COPY go.sum go.sum # cache deps before building and copying source so that we don't need to re-download as much # and so that source changes don't invalidate our downloaded layer -RUN go env -w GOPROXY=https://goproxy.cn,direct RUN go mod download # Copy the go source diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 38c5a4d6d..b58fd2a64 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.1.39 +version: 0.1.40 appVersion: "0.0.1" keywords: diff --git a/deploy/charts/arcadia/templates/post-models.yaml b/deploy/charts/arcadia/templates/post-models.yaml index 45cd2e2cf..d25509b51 100644 --- a/deploy/charts/arcadia/templates/post-models.yaml +++ b/deploy/charts/arcadia/templates/post-models.yaml @@ -2,11 +2,12 @@ apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Model metadata: name: baichuan2-7b-chat + namespace: {{ .Release.Namespace }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "1" spec: - displayName: "baichuan2-7b" + displayName: "baichuan2-7b-chat" description: | Baichuan2 为百川智能推出的新一代开源大语言模型,采用 2.6 万亿 Tokens 的高质量语料训练。 模型在通用、法律、医疗、数学、代码和多语言翻译六个领域的中英文和多语言权威数据集上进行了广泛测试,取得同尺寸中显著的优秀效果。 @@ -29,6 +30,7 @@ apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Model metadata: name: chatglm2-6b + namespace: {{ .Release.Namespace }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "1" @@ -59,6 +61,7 @@ apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Model metadata: name: qwen-7b-chat + namespace: {{ .Release.Namespace }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "1" @@ -90,6 +93,7 @@ apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Model metadata: name: bge-large-zh-v1.5 + namespace: {{ .Release.Namespace }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "1" @@ -121,6 +125,7 @@ apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: Model metadata: name: m3e-base + namespace: {{ .Release.Namespace }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "1" diff --git a/deploy/charts/arcadia/templates/resource_reader.yaml b/deploy/charts/arcadia/templates/resource_reader.yaml new file mode 100644 index 000000000..9ac954472 --- /dev/null +++ b/deploy/charts/arcadia/templates/resource_reader.yaml @@ -0,0 +1,38 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: kubeagi-reader + namespace: {{ .Release.Namespace }} +rules: +- apiGroups: + - arcadia.kubeagi.k8s.com.cn + resources: + - models + verbs: + - get + - list +- apiGroups: + - arcadia.kubeagi.k8s.com.cn + resources: + - models/finalizers + verbs: + - update +- apiGroups: + - arcadia.kubeagi.k8s.com.cn + resources: + - models/status + verbs: + - get +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: kubeagi-reader + namespace: {{ .Release.Namespace }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: kubeagi-reader +subjects: +- kind: Group + name: resource-reader \ No newline at end of file diff --git a/graphql-server/go-server/config/config.go b/graphql-server/go-server/config/config.go index f7a49f84b..1ba0bb2f4 100644 --- a/graphql-server/go-server/config/config.go +++ b/graphql-server/go-server/config/config.go @@ -17,13 +17,18 @@ package config import ( "flag" + "os" "k8s.io/klog/v2" "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/dataprocessing" ) +var s = &ServerConfig{} + type ServerConfig struct { + SystemNamespace string + Host string Port int EnablePlayground bool @@ -36,7 +41,7 @@ type ServerConfig struct { } func NewServerFlags() ServerConfig { - s := ServerConfig{} + flag.StringVar(&s.SystemNamespace, "system-namespace", os.Getenv("POD_NAMESPACE"), "system namespace where kubeagi has been installed") flag.StringVar(&s.Host, "host", "", "bind to the host, default is 0.0.0.0") flag.IntVar(&s.Port, "port", 8081, "service listening port") flag.BoolVar(&s.EnablePlayground, "enable-playground", false, "enable the graphql playground") @@ -52,5 +57,9 @@ func NewServerFlags() ServerConfig { flag.Parse() dataprocessing.Init(s.DataProcessURL) - return s + return *s +} + +func GetConfig() ServerConfig { + return *s } diff --git a/graphql-server/go-server/graph/generated/generated.go b/graphql-server/go-server/graph/generated/generated.go index 0719c7725..4a53ae227 100644 --- a/graphql-server/go-server/graph/generated/generated.go +++ b/graphql-server/go-server/graph/generated/generated.go @@ -333,6 +333,7 @@ type ComplexityRoot struct { Name func(childComplexity int) int Namespace func(childComplexity int) int Status func(childComplexity int) int + SystemModel func(childComplexity int) int Types func(childComplexity int) int UpdateTimestamp func(childComplexity int) int } @@ -345,7 +346,7 @@ type ComplexityRoot struct { ModelQuery struct { GetModel func(childComplexity int, name string, namespace string) int - ListModels func(childComplexity int, input ListCommonInput) int + ListModels func(childComplexity int, input ListModelInput) int } Mutation struct { @@ -553,7 +554,7 @@ type ModelMutationResolver interface { } type ModelQueryResolver interface { GetModel(ctx context.Context, obj *ModelQuery, name string, namespace string) (*Model, error) - ListModels(ctx context.Context, obj *ModelQuery, input ListCommonInput) (*PaginatedResult, error) + ListModels(ctx context.Context, obj *ModelQuery, input ListModelInput) (*PaginatedResult, error) } type MutationResolver interface { Hello(ctx context.Context, name string) (string, error) @@ -1989,6 +1990,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Model.Status(childComplexity), true + case "Model.systemModel": + if e.complexity.Model.SystemModel == nil { + break + } + + return e.complexity.Model.SystemModel(childComplexity), true + case "Model.types": if e.complexity.Model.Types == nil { break @@ -2061,7 +2069,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.ModelQuery.ListModels(childComplexity, args["input"].(ListCommonInput)), true + return e.complexity.ModelQuery.ListModels(childComplexity, args["input"].(ListModelInput)), true case "Mutation.dataProcess": if e.complexity.Mutation.DataProcess == nil { @@ -2769,6 +2777,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputListCommonInput, ec.unmarshalInputListDatasetInput, ec.unmarshalInputListKnowledgeBaseInput, + ec.unmarshalInputListModelInput, ec.unmarshalInputListVersionedDatasetInput, ec.unmarshalInputListWorkerInput, ec.unmarshalInputOssInput, @@ -4045,6 +4054,13 @@ type Model { """ namespace: String! + """ + 模型是否是由系统提供 + 规则: 如果为true,则是系统系统的。 + 规则: 如果是系统提供的模型,不允许修改 + """ + systemModel: Boolean + """一些用于标记,选择的的标签""" labels: Map """添加一些辅助性记录信息""" @@ -4135,9 +4151,41 @@ type ModelMutation { deleteModels(input: DeleteCommonInput): Void } +input ListModelInput { + namespace: String! + + """ + 是否包含系统提供的模型 + 规则: 为true时,代表将同时获取系统提供的模型 + 规则: 默认为false + """ + systemModel: Boolean + + """ + 关键词: 模糊匹配 + """ + keyword: String + + """标签选择器""" + labelSelector: String + """字段选择器""" + fieldSelector: String + """ + 分页页码, + 规则: 从1开始,默认是1 + """ + page: Int + + """ + 每页数量, + 规则: 默认10 + """ + pageSize: Int +} + type ModelQuery { getModel(name: String!, namespace: String!): Model! - listModels(input: ListCommonInput!): PaginatedResult! + listModels(input: ListModelInput!): PaginatedResult! } extend type Mutation { @@ -4532,6 +4580,10 @@ input ListWorkerInput { """ pageSize: Int + """ + worker对应的模型类型 + 规则: 模型分为embedding和llm两大类。如果两者都有,则通过逗号隔开,如: "embedding,llm" + """ modelTypes: String } @@ -5157,10 +5209,10 @@ func (ec *executionContext) field_ModelQuery_getModel_args(ctx context.Context, func (ec *executionContext) field_ModelQuery_listModels_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 ListCommonInput + var arg0 ListModelInput if tmp, ok := rawArgs["input"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) - arg0, err = ec.unmarshalNListCommonInput2githubᚗcomᚋkubeagiᚋarcadiaᚋgraphqlᚑserverᚋgoᚑserverᚋgraphᚋgeneratedᚐListCommonInput(ctx, tmp) + arg0, err = ec.unmarshalNListModelInput2githubᚗcomᚋkubeagiᚋarcadiaᚋgraphqlᚑserverᚋgoᚑserverᚋgraphᚋgeneratedᚐListModelInput(ctx, tmp) if err != nil { return nil, err } @@ -13478,6 +13530,47 @@ func (ec *executionContext) fieldContext_Model_namespace(ctx context.Context, fi return fc, nil } +func (ec *executionContext) _Model_systemModel(ctx context.Context, field graphql.CollectedField, obj *Model) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Model_systemModel(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.SystemModel, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*bool) + fc.Result = res + return ec.marshalOBoolean2ᚖbool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Model_systemModel(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Model", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Boolean does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Model_labels(ctx context.Context, field graphql.CollectedField, obj *Model) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Model_labels(ctx, field) if err != nil { @@ -13962,6 +14055,8 @@ func (ec *executionContext) fieldContext_ModelMutation_createModel(ctx context.C return ec.fieldContext_Model_name(ctx, field) case "namespace": return ec.fieldContext_Model_namespace(ctx, field) + case "systemModel": + return ec.fieldContext_Model_systemModel(ctx, field) case "labels": return ec.fieldContext_Model_labels(ctx, field) case "annotations": @@ -14045,6 +14140,8 @@ func (ec *executionContext) fieldContext_ModelMutation_updateModel(ctx context.C return ec.fieldContext_Model_name(ctx, field) case "namespace": return ec.fieldContext_Model_namespace(ctx, field) + case "systemModel": + return ec.fieldContext_Model_systemModel(ctx, field) case "labels": return ec.fieldContext_Model_labels(ctx, field) case "annotations": @@ -14180,6 +14277,8 @@ func (ec *executionContext) fieldContext_ModelQuery_getModel(ctx context.Context return ec.fieldContext_Model_name(ctx, field) case "namespace": return ec.fieldContext_Model_namespace(ctx, field) + case "systemModel": + return ec.fieldContext_Model_systemModel(ctx, field) case "labels": return ec.fieldContext_Model_labels(ctx, field) case "annotations": @@ -14232,7 +14331,7 @@ func (ec *executionContext) _ModelQuery_listModels(ctx context.Context, field gr }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.ModelQuery().ListModels(rctx, obj, fc.Args["input"].(ListCommonInput)) + return ec.resolvers.ModelQuery().ListModels(rctx, obj, fc.Args["input"].(ListModelInput)) }) if err != nil { ec.Error(ctx, err) @@ -22074,6 +22173,89 @@ func (ec *executionContext) unmarshalInputListKnowledgeBaseInput(ctx context.Con return it, nil } +func (ec *executionContext) unmarshalInputListModelInput(ctx context.Context, obj interface{}) (ListModelInput, error) { + var it ListModelInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"namespace", "systemModel", "keyword", "labelSelector", "fieldSelector", "page", "pageSize"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "namespace": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("namespace")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Namespace = data + case "systemModel": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("systemModel")) + data, err := ec.unmarshalOBoolean2ᚖbool(ctx, v) + if err != nil { + return it, err + } + it.SystemModel = data + case "keyword": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("keyword")) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + it.Keyword = data + case "labelSelector": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("labelSelector")) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + it.LabelSelector = data + case "fieldSelector": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("fieldSelector")) + data, err := ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + it.FieldSelector = data + case "page": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("page")) + data, err := ec.unmarshalOInt2ᚖint(ctx, v) + if err != nil { + return it, err + } + it.Page = data + case "pageSize": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("pageSize")) + data, err := ec.unmarshalOInt2ᚖint(ctx, v) + if err != nil { + return it, err + } + it.PageSize = data + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputListVersionedDatasetInput(ctx context.Context, obj interface{}) (ListVersionedDatasetInput, error) { var it ListVersionedDatasetInput asMap := map[string]interface{}{} @@ -25867,6 +26049,8 @@ func (ec *executionContext) _Model(ctx context.Context, sel ast.SelectionSet, ob if out.Values[i] == graphql.Null { atomic.AddUint32(&out.Invalids, 1) } + case "systemModel": + out.Values[i] = ec._Model_systemModel(ctx, field, obj) case "labels": out.Values[i] = ec._Model_labels(ctx, field, obj) case "annotations": @@ -28109,6 +28293,11 @@ func (ec *executionContext) unmarshalNListKnowledgeBaseInput2githubᚗcomᚋkube return res, graphql.ErrorOnPath(ctx, err) } +func (ec *executionContext) unmarshalNListModelInput2githubᚗcomᚋkubeagiᚋarcadiaᚋgraphqlᚑserverᚋgoᚑserverᚋgraphᚋgeneratedᚐListModelInput(ctx context.Context, v interface{}) (ListModelInput, error) { + res, err := ec.unmarshalInputListModelInput(ctx, v) + return res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) unmarshalNListVersionedDatasetInput2githubᚗcomᚋkubeagiᚋarcadiaᚋgraphqlᚑserverᚋgoᚑserverᚋgraphᚋgeneratedᚐListVersionedDatasetInput(ctx context.Context, v interface{}) (ListVersionedDatasetInput, error) { res, err := ec.unmarshalInputListVersionedDatasetInput(ctx, v) return res, graphql.ErrorOnPath(ctx, err) diff --git a/graphql-server/go-server/graph/generated/models_gen.go b/graphql-server/go-server/graph/generated/models_gen.go index e0aa8dc75..2aae74774 100644 --- a/graphql-server/go-server/graph/generated/models_gen.go +++ b/graphql-server/go-server/graph/generated/models_gen.go @@ -727,6 +727,26 @@ type ListKnowledgeBaseInput struct { Keyword *string `json:"keyword,omitempty"` } +type ListModelInput struct { + Namespace string `json:"namespace"` + // 是否包含系统提供的模型 + // 规则: 为true时,代表将同时获取系统提供的模型 + // 规则: 默认为false + SystemModel *bool `json:"systemModel,omitempty"` + // 关键词: 模糊匹配 + Keyword *string `json:"keyword,omitempty"` + // 标签选择器 + LabelSelector *string `json:"labelSelector,omitempty"` + // 字段选择器 + FieldSelector *string `json:"fieldSelector,omitempty"` + // 分页页码, + // 规则: 从1开始,默认是1 + Page *int `json:"page,omitempty"` + // 每页数量, + // 规则: 默认10 + PageSize *int `json:"pageSize,omitempty"` +} + type ListVersionedDatasetInput struct { Name *string `json:"name,omitempty"` Namespace *string `json:"namespace,omitempty"` @@ -753,7 +773,9 @@ type ListWorkerInput struct { Page *int `json:"page,omitempty"` // 每页数量, // 规则: 默认10 - PageSize *int `json:"pageSize,omitempty"` + PageSize *int `json:"pageSize,omitempty"` + // worker对应的模型类型 + // 规则: 模型分为embedding和llm两大类。如果两者都有,则通过逗号隔开,如: "embedding,llm" ModelTypes *string `json:"modelTypes,omitempty"` } @@ -768,6 +790,10 @@ type Model struct { // 规则: 获取当前项目对应的命名空间 // 规则: 非空 Namespace string `json:"namespace"` + // 模型是否是由系统提供 + // 规则: 如果为true,则是系统系统的。 + // 规则: 如果是系统提供的模型,不允许修改 + SystemModel *bool `json:"systemModel,omitempty"` // 一些用于标记,选择的的标签 Labels map[string]interface{} `json:"labels,omitempty"` // 添加一些辅助性记录信息 diff --git a/graphql-server/go-server/graph/impl/model.resolvers.go b/graphql-server/go-server/graph/impl/model.resolvers.go index 8090d96c6..9532a1c76 100644 --- a/graphql-server/go-server/graph/impl/model.resolvers.go +++ b/graphql-server/go-server/graph/impl/model.resolvers.go @@ -60,7 +60,7 @@ func (r *modelQueryResolver) GetModel(ctx context.Context, obj *generated.ModelQ } // ListModels is the resolver for the listModels field. -func (r *modelQueryResolver) ListModels(ctx context.Context, obj *generated.ModelQuery, input generated.ListCommonInput) (*generated.PaginatedResult, error) { +func (r *modelQueryResolver) ListModels(ctx context.Context, obj *generated.ModelQuery, input generated.ListModelInput) (*generated.PaginatedResult, error) { c, err := getClientFromCtx(ctx) if err != nil { return nil, err diff --git a/graphql-server/go-server/graph/schema/model.gql b/graphql-server/go-server/graph/schema/model.gql index d5747084d..0cdfbc0fe 100644 --- a/graphql-server/go-server/graph/schema/model.gql +++ b/graphql-server/go-server/graph/schema/model.gql @@ -1,5 +1,5 @@ # list -query listModels($input: ListCommonInput!,$filesInput: FileFilter){ +query listModels($input: ListModelInput!,$filesInput: FileFilter){ Model { listModels(input: $input) { totalCount @@ -11,6 +11,7 @@ query listModels($input: ListCommonInput!,$filesInput: FileFilter){ creationTimestamp name namespace + system labels annotations creator @@ -47,6 +48,7 @@ query getModel($name: String!, $namespace: String!,$filesInput: FileFilter) { creationTimestamp name namespace + system labels annotations creator @@ -81,6 +83,7 @@ mutation createModel($input: CreateModelInput!) { creationTimestamp name namespace + system labels annotations creator @@ -101,6 +104,7 @@ mutation updateModel($input: UpdateModelInput) { creationTimestamp name namespace + system labels annotations creator diff --git a/graphql-server/go-server/graph/schema/model.graphqls b/graphql-server/go-server/graph/schema/model.graphqls index 67c4d1783..5cdcec173 100644 --- a/graphql-server/go-server/graph/schema/model.graphqls +++ b/graphql-server/go-server/graph/schema/model.graphqls @@ -18,6 +18,13 @@ type Model { """ namespace: String! + """ + 模型是否是由系统提供 + 规则: 如果为true,则是系统系统的。 + 规则: 如果是系统提供的模型,不允许修改 + """ + systemModel: Boolean + """一些用于标记,选择的的标签""" labels: Map """添加一些辅助性记录信息""" @@ -108,9 +115,41 @@ type ModelMutation { deleteModels(input: DeleteCommonInput): Void } +input ListModelInput { + namespace: String! + + """ + 是否包含系统提供的模型 + 规则: 为true时,代表将同时获取系统提供的模型 + 规则: 默认为false + """ + systemModel: Boolean + + """ + 关键词: 模糊匹配 + """ + keyword: String + + """标签选择器""" + labelSelector: String + """字段选择器""" + fieldSelector: String + """ + 分页页码, + 规则: 从1开始,默认是1 + """ + page: Int + + """ + 每页数量, + 规则: 默认10 + """ + pageSize: Int +} + type ModelQuery { getModel(name: String!, namespace: String!): Model! - listModels(input: ListCommonInput!): PaginatedResult! + listModels(input: ListModelInput!): PaginatedResult! } extend type Mutation { diff --git a/graphql-server/go-server/graph/schema/worker.gql b/graphql-server/go-server/graph/schema/worker.gql index bf876e8cc..0c89b5d61 100644 --- a/graphql-server/go-server/graph/schema/worker.gql +++ b/graphql-server/go-server/graph/schema/worker.gql @@ -48,6 +48,11 @@ query getWorker($name: String!, $namespace: String!) { updateTimestamp model modelTypes + resources { + cpu + memory + nvidiaGPU + } } } } diff --git a/graphql-server/go-server/graph/schema/worker.graphqls b/graphql-server/go-server/graph/schema/worker.graphqls index f9064bff8..cfc800c91 100644 --- a/graphql-server/go-server/graph/schema/worker.graphqls +++ b/graphql-server/go-server/graph/schema/worker.graphqls @@ -167,6 +167,10 @@ input ListWorkerInput { """ pageSize: Int + """ + worker对应的模型类型 + 规则: 模型分为embedding和llm两大类。如果两者都有,则通过逗号隔开,如: "embedding,llm" + """ modelTypes: String } diff --git a/graphql-server/go-server/pkg/common/schema.go b/graphql-server/go-server/pkg/common/schema.go index 447d8c4fc..15430b4a6 100644 --- a/graphql-server/go-server/pkg/common/schema.go +++ b/graphql-server/go-server/pkg/common/schema.go @@ -57,6 +57,16 @@ var ( Version: v1alpha1.GroupVersion.Version, Resource: "datasets", }, + "model": { + Group: v1alpha1.GroupVersion.Group, + Version: v1alpha1.GroupVersion.Version, + Resource: "models", + }, + "worker": { + Group: v1alpha1.GroupVersion.Group, + Version: v1alpha1.GroupVersion.Version, + Resource: "workers", + }, "application": { Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, diff --git a/graphql-server/go-server/pkg/embedder/embedder.go b/graphql-server/go-server/pkg/embedder/embedder.go index b14d2446d..f55c040f3 100644 --- a/graphql-server/go-server/pkg/embedder/embedder.go +++ b/graphql-server/go-server/pkg/embedder/embedder.go @@ -233,7 +233,12 @@ func ListEmbedders(ctx context.Context, c dynamic.Interface, input generated.Lis totalCount := len(us.Items) result := make([]generated.PageNode, 0, pageSize) - for _, u := range us.Items { + pageStart := (page - 1) * pageSize + for index, u := range us.Items { + // skip if smaller than the start index + if index < pageStart { + continue + } m := embedder2model(&u) // filter based on `keyword` if keyword != "" { diff --git a/graphql-server/go-server/pkg/model/model.go b/graphql-server/go-server/pkg/model/model.go index 4011e6512..9d136b465 100644 --- a/graphql-server/go-server/pkg/model/model.go +++ b/graphql-server/go-server/pkg/model/model.go @@ -31,7 +31,9 @@ import ( "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/graphql-server/go-server/config" "github.com/kubeagi/arcadia/graphql-server/go-server/graph/generated" + "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/common" "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/minio" graphqlutils "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/utils" "github.com/kubeagi/arcadia/pkg/utils" @@ -39,37 +41,40 @@ import ( ) func obj2model(obj *unstructured.Unstructured) *generated.Model { - id := string(obj.GetUID()) - creationtimestamp := obj.GetCreationTimestamp().Time - displayName, _, _ := unstructured.NestedString(obj.Object, "spec", "displayName") + model := &v1alpha1.Model{} + if err := utils.UnstructuredToStructured(obj, model); err != nil { + return &generated.Model{} + } - types, _, _ := unstructured.NestedString(obj.Object, "spec", "types") - description, _, _ := unstructured.NestedString(obj.Object, "spec", "description") - status := "" - var updateTime time.Time - conditions, found, _ := unstructured.NestedSlice(obj.Object, "status", "conditions") - if found && len(conditions) > 0 { - condition, ok := conditions[0].(map[string]interface{}) - if ok { - timeStr, _ := condition["lastTransitionTime"].(string) - updateTime, _ = utils.RFC3339Time(timeStr) - status, _ = condition["status"].(string) - } - } else { - status = "unknow" + id := string(model.GetUID()) + creationtimestamp := model.GetCreationTimestamp().Time + + // conditioned status + condition := model.Status.GetCondition(v1alpha1.TypeReady) + updateTime := condition.LastTransitionTime.Time + + // Unknown,Pending ,WorkerRunning ,Error + status := string(condition.Reason) + + var systemModel bool + if obj.GetNamespace() == config.GetConfig().SystemNamespace { + systemModel = true } + md := generated.Model{ ID: &id, Name: obj.GetName(), Namespace: obj.GetNamespace(), + Creator: &model.Spec.Creator, + SystemModel: &systemModel, Labels: graphqlutils.MapStr2Any(obj.GetLabels()), Annotations: graphqlutils.MapStr2Any(obj.GetAnnotations()), - DisplayName: &displayName, - Description: &description, - Status: &status, - Types: types, + DisplayName: &model.Spec.DisplayName, + Description: &model.Spec.Description, + Types: model.Spec.Types, CreationTimestamp: &creationtimestamp, UpdateTimestamp: &updateTime, + Status: &status, } return &md } @@ -177,7 +182,7 @@ func DeleteModels(ctx context.Context, c dynamic.Interface, input *generated.Del return nil, nil } -func ListModels(ctx context.Context, c dynamic.Interface, input generated.ListCommonInput) (*generated.PaginatedResult, error) { +func ListModels(ctx context.Context, c dynamic.Interface, input generated.ListModelInput) (*generated.PaginatedResult, error) { keyword, labelSelector, fieldSelector := "", "", "" page, pageSize := 1, 10 if input.Keyword != nil { @@ -196,24 +201,43 @@ func ListModels(ctx context.Context, c dynamic.Interface, input generated.ListCo pageSize = *input.PageSize } - dsSchema := schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "models"} listOptions := metav1.ListOptions{ LabelSelector: labelSelector, FieldSelector: fieldSelector, } - us, err := c.Resource(dsSchema).Namespace(input.Namespace).List(ctx, listOptions) + models, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "model")).Namespace(input.Namespace).List(ctx, listOptions) if err != nil { return nil, err } + // sort by creation time - sort.Slice(us.Items, func(i, j int) bool { - return us.Items[i].GetCreationTimestamp().After(us.Items[j].GetCreationTimestamp().Time) + sort.Slice(models.Items, func(i, j int) bool { + return models.Items[i].GetCreationTimestamp().After(models.Items[j].GetCreationTimestamp().Time) }) - totalCount := len(us.Items) + // list models in kubeagi system namespace + if input.SystemModel != nil && *input.SystemModel { + systemModels, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "model")).Namespace(config.GetConfig().SystemNamespace).List(ctx, listOptions) + if err != nil { + return nil, err + } + // sort by creation time + sort.Slice(systemModels.Items, func(i, j int) bool { + return systemModels.Items[i].GetCreationTimestamp().After(systemModels.Items[j].GetCreationTimestamp().Time) + }) + models.Items = append(systemModels.Items, models.Items...) + } + + totalCount := len(models.Items) result := make([]generated.PageNode, 0, pageSize) - for _, u := range us.Items { + pageStart := (page - 1) * pageSize + for index, u := range models.Items { + // skip if smaller than the start index + if index < pageStart { + continue + } + m := obj2model(&u) // filter based on `keyword` if keyword != "" { diff --git a/graphql-server/go-server/pkg/worker/worker.go b/graphql-server/go-server/pkg/worker/worker.go index ea9995fba..e068b4b84 100644 --- a/graphql-server/go-server/pkg/worker/worker.go +++ b/graphql-server/go-server/pkg/worker/worker.go @@ -26,11 +26,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/graphql-server/go-server/graph/generated" + "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/common" gqlmodel "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/model" graphqlutils "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/utils" "github.com/kubeagi/arcadia/pkg/utils" @@ -40,10 +40,6 @@ const ( NvidiaGPU = "nvidia.com/gpu" ) -var ( - scheme = schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "workers"} -) - func worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Unstructured) *generated.Worker { worker := &v1alpha1.Worker{} if err := utils.UnstructuredToStructured(obj, worker); err != nil { @@ -79,6 +75,7 @@ func worker2model(ctx context.Context, c dynamic.Interface, obj *unstructured.Un ID: &id, Name: worker.Name, Namespace: worker.Namespace, + Creator: &worker.Spec.Creator, Labels: graphqlutils.MapStr2Any(obj.GetLabels()), Annotations: graphqlutils.MapStr2Any(obj.GetAnnotations()), DisplayName: &worker.Spec.DisplayName, @@ -147,7 +144,7 @@ func CreateWorker(ctx context.Context, c dynamic.Interface, input generated.Crea if err != nil { return nil, err } - obj, err := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "workers"}). + obj, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "worker")). Namespace(input.Namespace).Create(ctx, &unstructured.Unstructured{Object: unstructuredWorker}, metav1.CreateOptions{}) if err != nil { return nil, err @@ -156,7 +153,7 @@ func CreateWorker(ctx context.Context, c dynamic.Interface, input generated.Crea } func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.UpdateWorkerInput) (*generated.Worker, error) { - obj, err := c.Resource(scheme).Namespace(input.Namespace).Get(ctx, input.Name, metav1.GetOptions{}) + obj, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "worker")).Namespace(input.Namespace).Get(ctx, input.Name, metav1.GetOptions{}) if err != nil { return nil, err } @@ -198,7 +195,12 @@ func UpdateWorker(ctx context.Context, c dynamic.Interface, input *generated.Upd return nil, err } - updatedObject, err := c.Resource(scheme).Namespace(input.Namespace).Update(ctx, &unstructured.Unstructured{Object: unstructuredWorker}, metav1.UpdateOptions{}) + updatedObject, err := common.ResouceUpdate(ctx, c, generated.TypedObjectReferenceInput{ + APIGroup: &common.ArcadiaAPIGroup, + Kind: "Worker", + Name: input.Name, + Namespace: &input.Namespace, + }, unstructuredWorker, metav1.UpdateOptions{}) if err != nil { return nil, err } @@ -218,7 +220,7 @@ func DeleteWorkers(ctx context.Context, c dynamic.Interface, input *generated.De if input.LabelSelector != nil { labelSelector = *input.LabelSelector } - resource := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "workers"}) + resource := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "worker")) if name != "" { err := resource.Namespace(input.Namespace).Delete(ctx, name, metav1.DeleteOptions{}) if err != nil { @@ -258,12 +260,11 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW pageSize = *input.PageSize } - workerSchema := schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "workers"} listOptions := metav1.ListOptions{ LabelSelector: labelSelector, FieldSelector: fieldSelector, } - us, err := c.Resource(workerSchema).Namespace(input.Namespace).List(ctx, listOptions) + us, err := c.Resource(common.SchemaOf(&common.ArcadiaAPIGroup, "worker")).Namespace(input.Namespace).List(ctx, listOptions) if err != nil { return nil, err } @@ -275,7 +276,12 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW totalCount := len(us.Items) result := make([]generated.PageNode, 0, pageSize) - for _, u := range us.Items { + pageStart := (page - 1) * pageSize + for index, u := range us.Items { + // skip if smaller than the start index + if index < pageStart { + continue + } m := worker2model(ctx, c, &u) // filter based on `keyword` if keyword != "" { @@ -310,8 +316,12 @@ func ListWorkers(ctx context.Context, c dynamic.Interface, input generated.ListW } func ReadWorker(ctx context.Context, c dynamic.Interface, name, namespace string) (*generated.Worker, error) { - resource := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "workers"}) - u, err := resource.Namespace(namespace).Get(ctx, name, metav1.GetOptions{}) + u, err := common.ResouceGet(ctx, c, generated.TypedObjectReferenceInput{ + APIGroup: &common.ArcadiaAPIGroup, + Kind: "Worker", + Name: name, + Namespace: &namespace, + }, metav1.GetOptions{}) if err != nil { return nil, err }