From 3d5751f49c5dd6ea57cffd672dadf75ba5f9ae4c Mon Sep 17 00:00:00 2001 From: bjwswang Date: Wed, 24 Jan 2024 03:49:23 +0000 Subject: [PATCH] feat: able to set/update models in LLM and Embedder Signed-off-by: bjwswang --- apiserver/graph/generated/generated.go | 64 ++++++++++++++++++++-- apiserver/graph/generated/models_gen.go | 8 +++ apiserver/graph/schema/embedder.graphqls | 10 ++++ apiserver/graph/schema/llm.graphqls | 10 ++++ apiserver/pkg/embedder/embedder.go | 8 ++- apiserver/pkg/llm/llm.go | 8 ++- apiserver/pkg/modelservice/modelservice.go | 62 +++++++++++---------- 7 files changed, 134 insertions(+), 36 deletions(-) diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index a5b398b00..52626ee34 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -4935,6 +4935,11 @@ input CreateEmbedderInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } input UpdateEmbedderInput { @@ -4961,6 +4966,11 @@ input UpdateEmbedderInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } type EmbedderQuery { @@ -5365,6 +5375,11 @@ input CreateLLMInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } input UpdateLLMInput { @@ -5390,6 +5405,11 @@ input UpdateLLMInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] 
} type LLMQuery { @@ -28240,7 +28260,7 @@ func (ec *executionContext) unmarshalInputCreateEmbedderInput(ctx context.Contex asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type"} + fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type", "models"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -28319,6 +28339,15 @@ func (ec *executionContext) unmarshalInputCreateEmbedderInput(ctx context.Contex return it, err } it.Type = data + case "models": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("models")) + data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + it.Models = data } } @@ -28433,7 +28462,7 @@ func (ec *executionContext) unmarshalInputCreateLLMInput(ctx context.Context, ob asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type"} + fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type", "models"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -28512,6 +28541,15 @@ func (ec *executionContext) unmarshalInputCreateLLMInput(ctx context.Context, ob return it, err } it.Type = data + case "models": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("models")) + data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + it.Models = data } } @@ -30702,7 +30740,7 @@ func (ec *executionContext) unmarshalInputUpdateEmbedderInput(ctx context.Contex asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type"} + fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", 
"description", "endpointinput", "type", "models"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -30781,6 +30819,15 @@ func (ec *executionContext) unmarshalInputUpdateEmbedderInput(ctx context.Contex return it, err } it.Type = data + case "models": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("models")) + data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + it.Models = data } } @@ -30877,7 +30924,7 @@ func (ec *executionContext) unmarshalInputUpdateLLMInput(ctx context.Context, ob asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type"} + fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "endpointinput", "type", "models"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -30956,6 +31003,15 @@ func (ec *executionContext) unmarshalInputUpdateLLMInput(ctx context.Context, ob return it, err } it.Type = data + case "models": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("models")) + data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + it.Models = data } } diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 560a09474..b9ffa468e 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -217,6 +217,8 @@ type CreateEmbedderInput struct { // 向量化模型服务接口类型 // 规则: 目前支持 zhipuai,openai两种接口类型 Type *string `json:"type,omitempty"` + // 此LLM支持调用的模型列表 + Models []string `json:"models,omitempty"` } // 创建知识库的输入 @@ -259,6 +261,8 @@ type CreateLLMInput struct { // 模型服务接口类型 // 规则: 目前支持 zhipuai,openai两种接口类型 Type *string `json:"type,omitempty"` + // 此LLM支持调用的模型列表 + Models []string `json:"models,omitempty"` } // 创建模型的输入 @@ -1307,6 +1311,8 @@ type UpdateEmbedderInput struct { // 
向量化模型服务接口类型 // 规则: 目前支持 zhipuai,openai两种接口类型 Type *string `json:"type,omitempty"` + // 此LLM支持调用的模型列表 + Models []string `json:"models,omitempty"` } // 知识库更新的输入 @@ -1345,6 +1351,8 @@ type UpdateLLMInput struct { // 模型服务接口类型 // 规则: 目前支持 zhipuai,openai两种接口类型 Type *string `json:"type,omitempty"` + // 此LLM支持调用的模型列表 + Models []string `json:"models,omitempty"` } // 模型更新的输入 diff --git a/apiserver/graph/schema/embedder.graphqls b/apiserver/graph/schema/embedder.graphqls index 8860df674..6d2a7315e 100644 --- a/apiserver/graph/schema/embedder.graphqls +++ b/apiserver/graph/schema/embedder.graphqls @@ -54,6 +54,11 @@ input CreateEmbedderInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } input UpdateEmbedderInput { @@ -80,6 +85,11 @@ input UpdateEmbedderInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } type EmbedderQuery { diff --git a/apiserver/graph/schema/llm.graphqls b/apiserver/graph/schema/llm.graphqls index 18fdd6c08..f68b9a7d7 100644 --- a/apiserver/graph/schema/llm.graphqls +++ b/apiserver/graph/schema/llm.graphqls @@ -54,6 +54,11 @@ input CreateLLMInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] } input UpdateLLMInput { @@ -79,6 +84,11 @@ input UpdateLLMInput { 规则: 目前支持 zhipuai,openai两种接口类型 """ type: String + + """ + 此LLM支持调用的模型列表 + """ + models: [String!] 
} type LLMQuery { diff --git a/apiserver/pkg/embedder/embedder.go b/apiserver/pkg/embedder/embedder.go index be0c48957..5c258ec92 100644 --- a/apiserver/pkg/embedder/embedder.go +++ b/apiserver/pkg/embedder/embedder.go @@ -130,7 +130,8 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr URL: input.Endpointinput.URL, }, }, - Type: embeddings.EmbeddingType(servicetype), + Type: embeddings.EmbeddingType(servicetype), + Models: input.Models, }, } @@ -210,6 +211,11 @@ func UpdateEmbedder(ctx context.Context, c dynamic.Interface, input *generated.U updatedEmbedder.Spec.Type = embeddings.EmbeddingType(*input.Type) } + // update Embedder's models if specified + if input.Models != nil { + updatedEmbedder.Spec.Models = input.Models + } + // Update endpoint if input.Endpointinput != nil { endpoint, err := common.MakeEndpoint(ctx, c, generated.TypedObjectReferenceInput{ diff --git a/apiserver/pkg/llm/llm.go b/apiserver/pkg/llm/llm.go index c54ffa1c8..1b7ea414f 100644 --- a/apiserver/pkg/llm/llm.go +++ b/apiserver/pkg/llm/llm.go @@ -185,7 +185,8 @@ func CreateLLM(ctx context.Context, c dynamic.Interface, input generated.CreateL URL: input.Endpointinput.URL, }, }, - Type: llms.LLMType(APIType), + Type: llms.LLMType(APIType), + Models: input.Models, }, } common.SetCreator(ctx, &llm.Spec.CommonSpec) @@ -266,6 +267,11 @@ func UpdateLLM(ctx context.Context, c dynamic.Interface, input *generated.Update updatedLLM.Spec.Type = llms.LLMType(*input.Type) } + // update LLM's models if specified + if input.Models != nil { + updatedLLM.Spec.Models = input.Models + } + // Update endpoint if input.Endpointinput != nil { endpoint, err := common.MakeEndpoint(ctx, c, generated.TypedObjectReferenceInput{ diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index 84e402f61..f07cf39a5 100644 --- a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -22,7 +22,6 @@ import ( "fmt"
"sort" "strings" - "time" "github.com/tmc/langchaingo/llms" "k8s.io/client-go/dynamic" @@ -59,7 +58,7 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate } var modelSerivce = &generated.ModelService{} - + var llmModels, embeddingModels []string // Create LLM if serviceType contains llm if strings.Contains(serviceType, "llm") { llm, err := llm.CreateLLM(ctx, c, generated.CreateLLMInput{ @@ -71,10 +70,12 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate Annotations: input.Annotations, Type: &APIType, Endpointinput: input.Endpoint, + Models: input.LlmModels, }) if err != nil { return nil, err } + llmModels = llm.Models modelSerivce = LLM2ModelService(llm) } @@ -89,15 +90,19 @@ func CreateModelService(ctx context.Context, c dynamic.Interface, input generate Annotations: input.Annotations, Type: &APIType, Endpointinput: input.Endpoint, + Models: input.EmbeddingModels, }) if err != nil { return nil, err } - + embeddingModels = embedder.Models modelSerivce = Embedder2ModelService(embedder) } + // merge llm&embedder modelSerivce.Types = &serviceType + modelSerivce.LlmModels = llmModels + modelSerivce.EmbeddingModels = embeddingModels return modelSerivce, nil } @@ -150,6 +155,7 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input *generat Annotations: newAnnotations, Type: &newAPIType, Endpointinput: &input.Endpoint, + Models: input.LlmModels, } updateEmbedderInput := generated.UpdateEmbedderInput{ @@ -161,15 +167,19 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input *generat Annotations: newAnnotations, Type: &newAPIType, Endpointinput: &input.Endpoint, + Models: input.EmbeddingModels, } // TODO: codes to delete/create llm/embedding resource if input.Types is changed. For now it will not work. 
- + var updatedModelSerivce = &generated.ModelService{} + var llmModels, embeddingModels []string if strings.Contains(*ms.Types, "llm") { updatedLLM, err = llm.UpdateLLM(ctx, c, &updateLLMInput) if err != nil { return nil, errors.New("update LLM failed: " + err.Error()) } + llmModels = updatedLLM.Models + updatedModelSerivce = LLM2ModelService(updatedLLM) } if strings.Contains(*ms.Types, "embedding") { @@ -177,31 +187,16 @@ func UpdateModelService(ctx context.Context, c dynamic.Interface, input *generat if err != nil { return nil, errors.New("update embedding failed: " + err.Error()) } + embeddingModels = updatedEmbedder.Models + updatedModelSerivce = Embedder2ModelService(updatedEmbedder) } - var creationTimestamp, updateTimestamp *time.Time - - if updatedLLM != nil { - creationTimestamp = updatedLLM.CreationTimestamp - updateTimestamp = updatedLLM.UpdateTimestamp - } else if updatedEmbedder != nil { - creationTimestamp = updatedLLM.CreationTimestamp - updateTimestamp = updatedLLM.UpdateTimestamp - } + // merge llm&embedder + updatedModelSerivce.Types = ms.Types + updatedModelSerivce.LlmModels = llmModels + updatedModelSerivce.EmbeddingModels = embeddingModels - ds := &generated.ModelService{ - Name: input.Name, - Namespace: input.Namespace, - DisplayName: &newDisplayName, - Description: &newDescription, - Labels: newLabels, - Annotations: newAnnotations, - Types: ms.Types, - APIType: &newAPIType, - CreationTimestamp: creationTimestamp, - UpdateTimestamp: updateTimestamp, - } - return ds, nil + return updatedModelSerivce, nil } // DeleteModelService deletes a 3rd_party model service @@ -232,19 +227,26 @@ func DeleteModelService(ctx context.Context, c dynamic.Interface, input *generat // ReadModelService get a 3rd_party model service func ReadModelService(ctx context.Context, c dynamic.Interface, name string, namespace string) (*generated.ModelService, error) { var modelService = &generated.ModelService{} - + var serviceTypes []string + var llmModels, 
embeddingModels []string llm, err := llm.ReadLLM(ctx, c, name, namespace) if err == nil { + llmModels = llm.Models + serviceTypes = append(serviceTypes, common.ModelTypeLLM) modelService = LLM2ModelService(llm) } embedder, err := embedder.ReadEmbedder(ctx, c, name, namespace) if err == nil { + embeddingModels = embedder.Models + serviceTypes = append(serviceTypes, common.ModelTypeEmbedding) modelService = Embedder2ModelService(embedder) } - if llm != nil && embedder != nil { - modelService.Types = &common.ModelTypeAll - } + serviceTypeStr := strings.Join(serviceTypes, ",") + modelService.Types = &serviceTypeStr + modelService.LlmModels = llmModels + modelService.EmbeddingModels = embeddingModels + return modelService, nil }