diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index b2ea60220..cd254e099 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -445,6 +445,7 @@ type ComplexityRoot struct { } ModelServiceQuery struct { + CheckModelService func(childComplexity int, input CreateModelServiceInput) int GetModelService func(childComplexity int, name string, namespace string, apiType string) int ListModelServices func(childComplexity int, input *ListModelService) int } @@ -676,6 +677,7 @@ type ModelServiceMutationResolver interface { type ModelServiceQueryResolver interface { GetModelService(ctx context.Context, obj *ModelServiceQuery, name string, namespace string, apiType string) (*ModelService, error) ListModelServices(ctx context.Context, obj *ModelServiceQuery, input *ListModelService) (*PaginatedResult, error) + CheckModelService(ctx context.Context, obj *ModelServiceQuery, input CreateModelServiceInput) (*ModelService, error) } type MutationResolver interface { Hello(ctx context.Context, name string) (string, error) @@ -2722,6 +2724,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ModelServiceMutation.UpdateModelService(childComplexity, args["input"].(*UpdateModelServiceInput)), true + case "ModelServiceQuery.checkModelService": + if e.complexity.ModelServiceQuery.CheckModelService == nil { + break + } + + args, err := ec.field_ModelServiceQuery_checkModelService_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.ModelServiceQuery.CheckModelService(childComplexity, args["input"].(CreateModelServiceInput)), true + case "ModelServiceQuery.getModelService": if e.complexity.ModelServiceQuery.GetModelService == nil { break @@ -5326,6 +5340,7 @@ extend type Mutation { type ModelServiceQuery { getModelService(name: String!, namespace: String!, apiType: String!): ModelService listModelServices(input: ListModelService): PaginatedResult! + checkModelService(input: CreateModelServiceInput!): ModelService! } extend type Query { @@ -6498,6 +6513,21 @@ func (ec *executionContext) field_ModelServiceMutation_updateModelService_args(c return args, nil } +func (ec *executionContext) field_ModelServiceQuery_checkModelService_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 CreateModelServiceInput + if tmp, ok := rawArgs["input"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("input")) + arg0, err = ec.unmarshalNCreateModelServiceInput2githubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐCreateModelServiceInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg0 + return args, nil +} + func (ec *executionContext) field_ModelServiceQuery_getModelService_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -19169,6 +19199,93 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_listModelServices(ctx return fc, nil } +func (ec *executionContext) _ModelServiceQuery_checkModelService(ctx context.Context, field graphql.CollectedField, obj *ModelServiceQuery) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ModelServiceQuery_checkModelService(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.ModelServiceQuery().CheckModelService(rctx, obj, fc.Args["input"].(CreateModelServiceInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*ModelService) + fc.Result = res + return ec.marshalNModelService2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐModelService(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ModelServiceQuery_checkModelService(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ModelServiceQuery", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_ModelService_id(ctx, field) + case "name": + return ec.fieldContext_ModelService_name(ctx, field) + case "namespace": + return ec.fieldContext_ModelService_namespace(ctx, field) + case "labels": + return ec.fieldContext_ModelService_labels(ctx, field) + case "annotations": + return ec.fieldContext_ModelService_annotations(ctx, field) + case "creator": + return ec.fieldContext_ModelService_creator(ctx, field) + case "displayName": + return ec.fieldContext_ModelService_displayName(ctx, field) + case "description": + return ec.fieldContext_ModelService_description(ctx, field) + case "types": + return ec.fieldContext_ModelService_types(ctx, field) + case "creationTimestamp": + return ec.fieldContext_ModelService_creationTimestamp(ctx, field) + case "updateTimestamp": + return ec.fieldContext_ModelService_updateTimestamp(ctx, field) + case "apiType": + return ec.fieldContext_ModelService_apiType(ctx, field) + case "llmResource": + return ec.fieldContext_ModelService_llmResource(ctx, field) + case "embedderResource": + return ec.fieldContext_ModelService_embedderResource(ctx, field) + case "resource": + return ec.fieldContext_ModelService_resource(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ModelService", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_ModelServiceQuery_checkModelService_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return fc, err + } + return fc, nil +} + func (ec *executionContext) _Mutation_hello(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Mutation_hello(ctx, field) if err != nil { @@ -20631,6 +20748,8 @@ func (ec *executionContext) fieldContext_Query_ModelService(ctx context.Context, return ec.fieldContext_ModelServiceQuery_getModelService(ctx, field) case "listModelServices": return ec.fieldContext_ModelServiceQuery_listModelServices(ctx, field) + case "checkModelService": + return ec.fieldContext_ModelServiceQuery_checkModelService(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type ModelServiceQuery", field.Name) }, @@ -32725,6 +32844,42 @@ func (ec *executionContext) _ModelServiceQuery(ctx context.Context, sel ast.Sele continue } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) + case "checkModelService": + field := field + + innerFunc := func(ctx context.Context, fs *graphql.FieldSet) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._ModelServiceQuery_checkModelService(ctx, field, obj) + if res == graphql.Null { + atomic.AddUint32(&fs.Invalids, 1) + } + return res + } + + if field.Deferrable != nil { + dfs, ok := deferred[field.Deferrable.Label] + di := 0 + if ok { + dfs.AddField(field) + di = len(dfs.Values) - 1 + } else { + dfs = graphql.NewFieldSet([]graphql.CollectedField{field}) + deferred[field.Deferrable.Label] = dfs + } + dfs.Concurrently(di, func(ctx context.Context) graphql.Marshaler { + return innerFunc(ctx, dfs) + }) + + // don't run the out.Concurrently() call below + out.Values[i] = graphql.Null + continue + } + out.Concurrently(i, func(ctx context.Context) graphql.Marshaler { return innerFunc(ctx, out) }) default: panic("unknown field " + strconv.Quote(field.Name)) diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index adb2b00a2..fafe01c0c 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -1018,6 +1018,7 @@ type ModelServiceMutation struct { type ModelServiceQuery struct { GetModelService *ModelService `json:"getModelService,omitempty"` ListModelServices PaginatedResult `json:"listModelServices"` + CheckModelService ModelService `json:"checkModelService"` } // 对象存储的使用信息 diff --git a/apiserver/graph/impl/modelservice.resolvers.go b/apiserver/graph/impl/modelservice.resolvers.go index 338c4d688..a61b31d24 100644 --- a/apiserver/graph/impl/modelservice.resolvers.go +++ b/apiserver/graph/impl/modelservice.resolvers.go @@ -56,6 +56,15 @@ func (r *modelServiceQueryResolver) ListModelServices(ctx context.Context, obj * return modelservice.ListModelServices(ctx, c, input) } +// CheckModelService is the resolver for the checkModelService field. +func (r *modelServiceQueryResolver) CheckModelService(ctx context.Context, obj *generated.ModelServiceQuery, input generated.CreateModelServiceInput) (*generated.ModelService, error) { + c, err := getClientFromCtx(ctx) + if err != nil { + return nil, err + } + return modelservice.CheckModelService(ctx, c, input) +} + // ModelService is the resolver for the ModelService field. func (r *mutationResolver) ModelService(ctx context.Context) (*generated.ModelServiceMutation, error) { return &generated.ModelServiceMutation{}, nil diff --git a/apiserver/graph/schema/modelservice.gql b/apiserver/graph/schema/modelservice.gql index b9828aa40..c2a57ab22 100644 --- a/apiserver/graph/schema/modelservice.gql +++ b/apiserver/graph/schema/modelservice.gql @@ -43,6 +43,11 @@ mutation createModelService($input: CreateModelServiceInput!) { status message } + resource { + cpu + nvidiaGPU + memory + } } } } @@ -92,6 +97,11 @@ mutation updateModelService($input: UpdateModelServiceInput) { status message } + resource { + cpu + nvidiaGPU + memory + } } } } @@ -215,4 +225,15 @@ query listModelServices($input: ListModelService) { } } } +} + +query checkModelService($input: CreateModelServiceInput!) { + ModelService { + checkModelService(input: $input) { + name + namespace + apiType + description + } + } } \ No newline at end of file diff --git a/apiserver/graph/schema/modelservice.graphqls b/apiserver/graph/schema/modelservice.graphqls index f73f37fdd..e5c689de0 100644 --- a/apiserver/graph/schema/modelservice.graphqls +++ b/apiserver/graph/schema/modelservice.graphqls @@ -137,6 +137,7 @@ extend type Mutation { type ModelServiceQuery { getModelService(name: String!, namespace: String!, apiType: String!): ModelService listModelServices(input: ListModelService): PaginatedResult! + checkModelService(input: CreateModelServiceInput!): ModelService! } extend type Query { diff --git a/apiserver/pkg/embedder/embedder.go b/apiserver/pkg/embedder/embedder.go index f8790b0fe..3ff0118ca 100644 --- a/apiserver/pkg/embedder/embedder.go +++ b/apiserver/pkg/embedder/embedder.go @@ -125,6 +125,7 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr // create auth secret secret := common.MakeAuthSecretName(embedder.Name, "embedder") err := common.MakeAuthSecret(ctx, c, generated.TypedObjectReferenceInput{ + Kind: "Secret", Name: secret, Namespace: &input.Namespace, }, input.Endpointinput.Auth, nil) @@ -153,6 +154,7 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr // user obj as the owner secret := common.MakeAuthSecretName(embedder.Name, "embedder") err := common.MakeAuthSecret(ctx, c, generated.TypedObjectReferenceInput{ + Kind: "Secret", Name: secret, Namespace: &input.Namespace, }, input.Endpointinput.Auth, obj) diff --git a/apiserver/pkg/llm/llm.go b/apiserver/pkg/llm/llm.go index 86cf1299d..d87139e7f 100644 --- a/apiserver/pkg/llm/llm.go +++ b/apiserver/pkg/llm/llm.go @@ -200,6 +200,7 @@ func CreateLLM(ctx context.Context, c dynamic.Interface, input generated.CreateL if input.Endpointinput.Auth != nil { secret := common.MakeAuthSecretName(llm.Name, "llm") err := common.MakeAuthSecret(ctx, c, generated.TypedObjectReferenceInput{ + Kind: "Secret", Name: secret, Namespace: &input.Namespace, }, input.Endpointinput.Auth, nil) @@ -218,7 +219,7 @@ func CreateLLM(ctx context.Context, c dynamic.Interface, input generated.CreateL return nil, err } - obj, err := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "LLM"}). + obj, err := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "llms"}). Namespace(input.Namespace).Create(ctx, &unstructured.Unstructured{Object: unstructuredLLM}, metav1.CreateOptions{}) if err != nil { return nil, err @@ -229,6 +230,7 @@ func CreateLLM(ctx context.Context, c dynamic.Interface, input generated.CreateL // user obj as the owner secret := common.MakeAuthSecretName(llm.Name, "LLM") err := common.MakeAuthSecret(ctx, c, generated.TypedObjectReferenceInput{ + Kind: "Secret", Name: secret, Namespace: &input.Namespace, }, input.Endpointinput.Auth, obj) @@ -271,7 +273,7 @@ func DeleteLLMs(ctx context.Context, c dynamic.Interface, input *generated.Delet labelSelector = *input.LabelSelector } - resource := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "LLMs"}) + resource := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "llms"}) if name != "" { err := resource.Namespace(input.Namespace).Delete(ctx, name, metav1.DeleteOptions{}) if err != nil { diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index a89390cca..c73f3ff12 100644 --- a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -18,6 +18,7 @@ package modelservice import ( "context" + "errors" "fmt" "strings" "time" @@ -29,6 +30,8 @@ import ( "github.com/kubeagi/arcadia/apiserver/pkg/embedder" "github.com/kubeagi/arcadia/apiserver/pkg/llm" "github.com/kubeagi/arcadia/apiserver/pkg/worker" + "github.com/kubeagi/arcadia/pkg/llms/openai" + "github.com/kubeagi/arcadia/pkg/llms/zhipuai" ) func CreateModelService(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (*generated.ModelService, error) { @@ -96,7 +99,7 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate ms := generated.ModelService{ // fulfill all params - // TBD: ID, Creator + // TBD: ID, Creator, Resource Name: input.Name, Namespace: input.Namespace, DisplayName: &displayName, @@ -162,20 +165,23 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input generate } func DeleteModelService(ctx context.Context, c dynamic.Interface, input *generated.DeleteCommonInput) (*string, error) { - _, err := embedder.DeleteEmbedders(ctx, c, input) - if err != nil { - return nil, err + var errText string + _, err1 := embedder.DeleteEmbedders(ctx, c, input) + if err1 != nil { + errText += "embedder: " + err1.Error() } - _, err = llm.DeleteLLMs(ctx, c, input) - if err != nil { - return nil, err + _, err2 := llm.DeleteLLMs(ctx, c, input) + if err2 != nil { + errText += " llm:" + err2.Error() + } + if errText != "" { + return nil, errors.New("error occurred during deleting: " + errText) } return nil, nil } var ( fixedPage = 1 - // because the data is to be paged, no parameters are provided, // and the default return is 10. Getting modelserve needs to get all the llms and embedding, // so a larger pageSize is provided here. @@ -540,3 +546,63 @@ func GetModelService(ctx context.Context, c dynamic.Interface, name, namespace, } return ms, nil } + +var ( + ErrWrongAuthFormat = errors.New("wrong auth format, auth[\"apikey\"] should be string") + ErrNoAuthProvided = errors.New("no auth provided") + ErrNoAPIKeyProvided = errors.New("no apiKey provided") +) + +func CheckModelService(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (*generated.ModelService, error) { + var err error + if input.Endpoint.Auth != nil { + var info string + if input.Endpoint.Auth["apiKey"] == nil { + return nil, ErrNoAPIKeyProvided + } + if _, ok := input.Endpoint.Auth["apiKey"].(string); !ok { + return nil, ErrWrongAuthFormat + } + + switch *input.APIType { + case "openai": + info, err = checkOpenAI(ctx, c, input) + case "zhipuai": + info, err = checkZhipuAI(ctx, c, input) + default: + err = fmt.Errorf("not support api type %s", *input.APIType) + } + + if err != nil { + return nil, err + } + return &generated.ModelService{ + // TODO: implement a ‘status' field for ModelService as a better place to store info instead of Description + Name: input.Name, + Namespace: input.Namespace, + APIType: input.APIType, + Description: &info, + }, nil + } + return nil, ErrNoAuthProvided +} + +func checkOpenAI(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (string, error) { + apiKey := input.Endpoint.Auth["apiKey"].(string) + client := openai.NewOpenAI(apiKey, input.Endpoint.URL) + res, err := client.Validate() + if err != nil { + return "", err + } + return res.String(), nil +} + +func checkZhipuAI(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (string, error) { + apiKey := input.Endpoint.Auth["apiKey"].(string) + client := zhipuai.NewZhiPuAI(apiKey) + res, err := client.Validate() + if err != nil { + return "", err + } + return res.String(), nil +} diff --git a/controllers/llm_controller.go b/controllers/llm_controller.go index cf2f7200f..45fd534ab 100644 --- a/controllers/llm_controller.go +++ b/controllers/llm_controller.go @@ -256,5 +256,13 @@ func (llm LLMPredicates) Update(ue event.UpdateEvent) bool { oldLLM := ue.ObjectOld.(*arcadiav1alpha1.LLM) newLLM := ue.ObjectNew.(*arcadiav1alpha1.LLM) - return !reflect.DeepEqual(oldLLM.Spec, newLLM.Spec) + return !reflect.DeepEqual(oldLLM.Spec, newLLM.Spec) || newLLM.DeletionTimestamp != nil +} + +func (llm LLMPredicates) Delete(de event.DeleteEvent) bool { + return true +} + +func (llm LLMPredicates) Generic(ge event.GenericEvent) bool { + return true } diff --git a/gqlgen.yaml b/gqlgen.yaml index 56e534540..da0ba75c5 100644 --- a/gqlgen.yaml +++ b/gqlgen.yaml @@ -252,3 +252,5 @@ models: resolver: true listModelServices: resolver: true + checkModelService: + resolver: true