diff --git a/api/base/v1alpha1/embedder.go b/api/base/v1alpha1/embedder.go
index 8e74f3d9e..e4b92be97 100644
--- a/api/base/v1alpha1/embedder.go
+++ b/api/base/v1alpha1/embedder.go
@@ -60,6 +60,12 @@ func (e Embedder) Get3rdPartyModels() []string {
 	if e.Spec.Provider.GetType() != ProviderType3rdParty {
 		return []string{}
 	}
+
+	// If customized models are provided, return them
+	if len(e.Spec.Models) != 0 {
+		return e.Spec.Models
+	}
+
 	switch e.Spec.Type {
 	case embeddings.ZhiPuAI:
 		return embeddings.ZhiPuAIModels
diff --git a/api/base/v1alpha1/embedder_types.go b/api/base/v1alpha1/embedder_types.go
index f017ec992..cb5ccd1dd 100644
--- a/api/base/v1alpha1/embedder_types.go
+++ b/api/base/v1alpha1/embedder_types.go
@@ -34,6 +34,10 @@ type EmbedderSpec struct {
 	// Provider defines the provider info which provide this embedder service
 	Provider `json:"provider,omitempty"`
+
+	// Models provided by this Embedder
+	// If not set, a default model list is used based on the embedding type
+	Models []string `json:"models,omitempty"`
 }
 
 // EmbeddingsStatus defines the observed state of Embedder
diff --git a/api/base/v1alpha1/llm.go b/api/base/v1alpha1/llm.go
index 13973eb8b..542e1249e 100644
--- a/api/base/v1alpha1/llm.go
+++ b/api/base/v1alpha1/llm.go
@@ -70,6 +70,12 @@ func (llm LLM) Get3rdPartyModels() []string {
 	if llm.Spec.Provider.GetType() != ProviderType3rdParty {
 		return []string{}
 	}
+
+	// If customized models are provided, return them
+	if len(llm.Spec.Models) != 0 {
+		return llm.Spec.Models
+	}
+
 	switch llm.Spec.Type {
 	case llms.ZhiPuAI:
 		return llms.ZhiPuAIModels
diff --git a/api/base/v1alpha1/llm_types.go b/api/base/v1alpha1/llm_types.go
index 26ffe1ab6..adbe2f411 100644
--- a/api/base/v1alpha1/llm_types.go
+++ b/api/base/v1alpha1/llm_types.go
@@ -31,6 +31,10 @@ type LLMSpec struct {
 	// Provider defines the provider info which provide this llm service
 	Provider `json:"provider,omitempty"`
+
+	// Models provided by this LLM
+	// If not set, a default model list is used based on the LLMType
+	Models []string `json:"models,omitempty"`
 }
 
 // LLMStatus defines the observed state of LLM
diff --git a/api/base/v1alpha1/zz_generated.deepcopy.go b/api/base/v1alpha1/zz_generated.deepcopy.go
index d85418489..8d870cce1 100644
--- a/api/base/v1alpha1/zz_generated.deepcopy.go
+++ b/api/base/v1alpha1/zz_generated.deepcopy.go
@@ -451,6 +451,11 @@ func (in *EmbedderSpec) DeepCopyInto(out *EmbedderSpec) {
 	*out = *in
 	out.CommonSpec = in.CommonSpec
 	in.Provider.DeepCopyInto(&out.Provider)
+	if in.Models != nil {
+		in, out := &in.Models, &out.Models
+		*out = make([]string, len(*in))
+		copy(*out, *in)
+	}
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new EmbedderSpec.
@@ -769,6 +774,11 @@ func (in *LLMSpec) DeepCopyInto(out *LLMSpec) {
 	*out = *in
 	out.CommonSpec = in.CommonSpec
 	in.Provider.DeepCopyInto(&out.Provider)
+	if in.Models != nil {
+		in, out := &in.Models, &out.Models
+		*out = make([]string, len(*in))
+		copy(*out, *in)
+	}
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new LLMSpec.
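The API change above makes an explicitly configured `spec.models` list take precedence over the per-type defaults. Below is a minimal, self-contained sketch of that precedence; the default model names are illustrative stand-ins, not the actual contents of `llms.ZhiPuAIModels`:

```go
package main

import "fmt"

// resolveModels mirrors the precedence Get3rdPartyModels now applies:
// an explicit, non-empty custom list wins; otherwise the defaults
// registered for the service type are returned.
func resolveModels(custom, defaults []string) []string {
	if len(custom) != 0 { // len(nil) == 0, so one check covers nil and empty
		return custom
	}
	return defaults
}

func main() {
	defaults := []string{"chatglm_lite", "chatglm_std", "chatglm_pro", "chatglm_turbo"}
	fmt.Println(resolveModels(nil, defaults))                            // falls back to defaults
	fmt.Println(resolveModels([]string{"my-finetuned-model"}, defaults)) // custom list wins
}
```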
diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go
index bb5f30aa5..36b8775fd 100644
--- a/apiserver/graph/generated/generated.go
+++ b/apiserver/graph/generated/generated.go
@@ -428,8 +428,10 @@ type ComplexityRoot struct {
 		Creator           func(childComplexity int) int
 		Description       func(childComplexity int) int
 		DisplayName       func(childComplexity int) int
+		EmbeddingModels   func(childComplexity int) int
 		ID                func(childComplexity int) int
 		Labels            func(childComplexity int) int
+		LlmModels         func(childComplexity int) int
 		Message           func(childComplexity int) int
 		Name              func(childComplexity int) int
 		Namespace         func(childComplexity int) int
@@ -2633,6 +2635,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
 
 		return e.complexity.ModelService.DisplayName(childComplexity), true
 
+	case "ModelService.embeddingModels":
+		if e.complexity.ModelService.EmbeddingModels == nil {
+			break
+		}
+
+		return e.complexity.ModelService.EmbeddingModels(childComplexity), true
+
 	case "ModelService.id":
 		if e.complexity.ModelService.ID == nil {
 			break
 		}
@@ -2647,6 +2656,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
 
 		return e.complexity.ModelService.Labels(childComplexity), true
 
+	case "ModelService.llmModels":
+		if e.complexity.ModelService.LlmModels == nil {
+			break
+		}
+
+		return e.complexity.ModelService.LlmModels(childComplexity), true
+
 	case "ModelService.message":
 		if e.complexity.ModelService.Message == nil {
 			break
 		}
@@ -5249,6 +5265,18 @@ extend type Query {
 
     apiType: String
 
+    """
+    模型服务的大语言模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    llmModels: [String!]
+
+    """
+    模型服务的Embedding模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    embeddingModels: [String!]
+
    """
    服务地址: 仅针对第三方模型服务
    """
@@ -5301,6 +5329,19 @@ input CreateModelServiceInput {
    模型服务终端输入
    """
    endpoint: EndpointInput!
+
+
+    """
+    模型服务的大语言模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    llmModels: [String!]
+
+    """
+    模型服务的Embedding模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    embeddingModels: [String!]
 }
 
 input UpdateModelServiceInput {
@@ -5334,6 +5375,18 @@ input UpdateModelServiceInput {
    模型服务终端输入
    """
    endpoint: EndpointInput!
+
+    """
+    模型服务的大语言模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    llmModels: [String!]
+
+    """
+    模型服务的Embedding模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    embeddingModels: [String!]
 }
 
 input ListModelServiceInput {
@@ -18702,6 +18755,88 @@ func (ec *executionContext) fieldContext_ModelService_apiType(ctx context.Contex
 	return fc, nil
 }
 
+func (ec *executionContext) _ModelService_llmModels(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) {
+	fc, err := ec.fieldContext_ModelService_llmModels(ctx, field)
+	if err != nil {
+		return graphql.Null
+	}
+	ctx = graphql.WithFieldContext(ctx, fc)
+	defer func() {
+		if r := recover(); r != nil {
+			ec.Error(ctx, ec.Recover(ctx, r))
+			ret = graphql.Null
+		}
+	}()
+	resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
+		ctx = rctx // use context from middleware stack in children
+		return obj.LlmModels, nil
+	})
+	if err != nil {
+		ec.Error(ctx, err)
+		return graphql.Null
+	}
+	if resTmp == nil {
+		return graphql.Null
+	}
+	res := resTmp.([]string)
+	fc.Result = res
+	return ec.marshalOString2ᚕstringᚄ(ctx, field.Selections, res)
+}
+
+func (ec *executionContext) fieldContext_ModelService_llmModels(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
+	fc = &graphql.FieldContext{
+		Object:     "ModelService",
+		Field:      field,
+		IsMethod:   false,
+		IsResolver: false,
+		Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
+			return nil, errors.New("field of type String does not have child fields")
+		},
+	}
+	return fc, nil
+}
+
+func (ec *executionContext) _ModelService_embeddingModels(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) {
+	fc, err := ec.fieldContext_ModelService_embeddingModels(ctx, field)
+	if err != nil {
+		return graphql.Null
+	}
+	ctx = graphql.WithFieldContext(ctx, fc)
+	defer func() {
+		if r := recover(); r != nil {
+			ec.Error(ctx, ec.Recover(ctx, r))
+			ret = graphql.Null
+		}
+	}()
+	resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
+		ctx = rctx // use context from middleware stack in children
+		return obj.EmbeddingModels, nil
+	})
+	if err != nil {
+		ec.Error(ctx, err)
+		return graphql.Null
+	}
+	if resTmp == nil {
+		return graphql.Null
+	}
+	res := resTmp.([]string)
+	fc.Result = res
+	return ec.marshalOString2ᚕstringᚄ(ctx, field.Selections, res)
+}
+
+func (ec *executionContext) fieldContext_ModelService_embeddingModels(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
+	fc = &graphql.FieldContext{
+		Object:     "ModelService",
+		Field:      field,
+		IsMethod:   false,
+		IsResolver: false,
+		Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
+			return nil, errors.New("field of type String does not have child fields")
+		},
+	}
+	return fc, nil
+}
+
 func (ec *executionContext) _ModelService_baseUrl(ctx context.Context, field graphql.CollectedField, obj *ModelService) (ret graphql.Marshaler) {
 	fc, err := ec.fieldContext_ModelService_baseUrl(ctx, field)
 	if err != nil {
@@ -18893,6 +19028,10 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_createModelService
 			return ec.fieldContext_ModelService_types(ctx, field)
 		case "apiType":
 			return ec.fieldContext_ModelService_apiType(ctx, field)
+		case "llmModels":
+			return ec.fieldContext_ModelService_llmModels(ctx, field)
+		case "embeddingModels":
+			return ec.fieldContext_ModelService_embeddingModels(ctx, field)
 		case "baseUrl":
 			return ec.fieldContext_ModelService_baseUrl(ctx, field)
 		case "status":
@@ -18982,6 +19121,10 @@ func (ec *executionContext) fieldContext_ModelServiceMutation_updateModelService
 			return ec.fieldContext_ModelService_types(ctx, field)
 		case "apiType":
 			return ec.fieldContext_ModelService_apiType(ctx, field)
+		case "llmModels":
+			return ec.fieldContext_ModelService_llmModels(ctx, field)
+		case "embeddingModels":
+			return ec.fieldContext_ModelService_embeddingModels(ctx, field)
 		case "baseUrl":
 			return ec.fieldContext_ModelService_baseUrl(ctx, field)
 		case "status":
@@ -19123,6 +19266,10 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_getModelService(ctx c
 			return ec.fieldContext_ModelService_types(ctx, field)
 		case "apiType":
 			return ec.fieldContext_ModelService_apiType(ctx, field)
+		case "llmModels":
+			return ec.fieldContext_ModelService_llmModels(ctx, field)
+		case "embeddingModels":
+			return ec.fieldContext_ModelService_embeddingModels(ctx, field)
 		case "baseUrl":
 			return ec.fieldContext_ModelService_baseUrl(ctx, field)
 		case "status":
@@ -19279,6 +19426,10 @@ func (ec *executionContext) fieldContext_ModelServiceQuery_checkModelService(ctx
 			return ec.fieldContext_ModelService_types(ctx, field)
 		case "apiType":
 			return ec.fieldContext_ModelService_apiType(ctx, field)
+		case "llmModels":
+			return ec.fieldContext_ModelService_llmModels(ctx, field)
+		case "embeddingModels":
+			return ec.fieldContext_ModelService_embeddingModels(ctx, field)
 		case "baseUrl":
 			return ec.fieldContext_ModelService_baseUrl(ctx, field)
 		case "status":
@@ -26642,7 +26793,7 @@ func (ec *executionContext) unmarshalInputCreateModelServiceInput(ctx context.Co
 		asMap[k] = v
 	}
 
-	fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "types", "apiType", "endpoint"}
+	fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "types", "apiType", "endpoint", "llmModels", "embeddingModels"}
 	for _, k := range fieldsInOrder {
 		v, ok := asMap[k]
 		if !ok {
@@ -26730,6 +26881,24 @@ func (ec *executionContext) unmarshalInputCreateModelServiceInput(ctx context.Co
 				return it, err
 			}
 			it.Endpoint = data
+		case "llmModels":
+			var err error
+
+			ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("llmModels"))
+			data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v)
+			if err != nil {
+				return it, err
+			}
+			it.LlmModels = data
+		case "embeddingModels":
+			var err error
+
+			ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embeddingModels"))
+			data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v)
+			if err != nil {
+				return it, err
+			}
+			it.EmbeddingModels = data
 		}
 	}
 
@@ -28790,7 +28959,7 @@ func (ec *executionContext) unmarshalInputUpdateModelServiceInput(ctx context.Co
 		asMap[k] = v
 	}
 
-	fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "types", "apiType", "endpoint"}
+	fieldsInOrder := [...]string{"name", "namespace", "labels", "annotations", "displayName", "description", "types", "apiType", "endpoint", "llmModels", "embeddingModels"}
 	for _, k := range fieldsInOrder {
 		v, ok := asMap[k]
 		if !ok {
@@ -28878,6 +29047,24 @@ func (ec *executionContext) unmarshalInputUpdateModelServiceInput(ctx context.Co
 				return it, err
 			}
 			it.Endpoint = data
+		case "llmModels":
+			var err error
+
+			ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("llmModels"))
+			data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v)
+			if err != nil {
+				return it, err
+			}
+			it.LlmModels = data
+		case "embeddingModels":
+			var err error
+
+			ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("embeddingModels"))
+			data, err := ec.unmarshalOString2ᚕstringᚄ(ctx, v)
+			if err != nil {
+				return it, err
+			}
+			it.EmbeddingModels = data
 		}
 	}
 
@@ -32616,6 +32803,10 @@ func (ec *executionContext) _ModelService(ctx context.Context, sel ast.Selection
 		case "types":
 			out.Values[i] = ec._ModelService_types(ctx, field, obj)
 		case "apiType":
 			out.Values[i] = ec._ModelService_apiType(ctx, field, obj)
+		case "llmModels":
+			out.Values[i] = ec._ModelService_llmModels(ctx, field, obj)
+		case "embeddingModels":
+			out.Values[i] = ec._ModelService_embeddingModels(ctx, field, obj)
 		case "baseUrl":
 			out.Values[i] = ec._ModelService_baseUrl(ctx, field, obj)
 			if out.Values[i] == graphql.Null {
diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go
index 90f5619cb..fa7fc7f15 100644
--- a/apiserver/graph/generated/models_gen.go
+++ b/apiserver/graph/generated/models_gen.go
@@ -288,6 +288,12 @@ type CreateModelServiceInput struct {
 	APIType *string `json:"apiType,omitempty"`
 	// 模型服务终端输入
 	Endpoint EndpointInput `json:"endpoint"`
+	// 模型服务的大语言模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	LlmModels []string `json:"llmModels,omitempty"`
+	// 模型服务的Embedding模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	EmbeddingModels []string `json:"embeddingModels,omitempty"`
 }
 
 type CreateVersionedDatasetInput struct {
@@ -1018,6 +1024,12 @@ type ModelService struct {
 	// 模型服务 API 类型
 	// 规则:支持 openai, zhipuai 两种类型
 	APIType *string `json:"apiType,omitempty"`
+	// 模型服务的大语言模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	LlmModels []string `json:"llmModels,omitempty"`
+	// 模型服务的Embedding模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	EmbeddingModels []string `json:"embeddingModels,omitempty"`
 	// 服务地址: 仅针对第三方模型服务
 	BaseURL string `json:"baseUrl"`
 	// 状态
@@ -1276,6 +1288,12 @@ type UpdateModelServiceInput struct {
 	APIType *string `json:"apiType,omitempty"`
 	// 模型服务终端输入
 	Endpoint EndpointInput `json:"endpoint"`
+	// 模型服务的大语言模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	LlmModels []string `json:"llmModels,omitempty"`
+	// 模型服务的Embedding模型列表
+	// 规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+	EmbeddingModels []string `json:"embeddingModels,omitempty"`
 }
 
 type UpdateVersionedDatasetInput struct {
diff --git a/apiserver/graph/schema/modelservice.gql b/apiserver/graph/schema/modelservice.gql
index 1dc1c5c3d..2e3b3a868 100644
--- a/apiserver/graph/schema/modelservice.gql
+++ b/apiserver/graph/schema/modelservice.gql
@@ -12,6 +12,8 @@ mutation createModelService($input: CreateModelServiceInput!) {
     providerType
     types
     apiType
+    llmModels
+    embeddingModels
     creationTimestamp
     updateTimestamp
     status
@@ -35,6 +37,8 @@ mutation updateModelService($input: UpdateModelServiceInput) {
     providerType
     types
     apiType
+    llmModels
+    embeddingModels
     creationTimestamp
     updateTimestamp
     status
@@ -64,6 +68,8 @@ query getModelService($name: String!, $namespace: String!) {
     providerType
     types
     apiType
+    llmModels
+    embeddingModels
     creationTimestamp
     updateTimestamp
     status
@@ -92,6 +98,8 @@ query listModelServices($input: ListModelServiceInput) {
     providerType
     types
     apiType
+    llmModels
+    embeddingModels
     creationTimestamp
     updateTimestamp
     status
diff --git a/apiserver/graph/schema/modelservice.graphqls b/apiserver/graph/schema/modelservice.graphqls
index 081498e13..528727fc2 100644
--- a/apiserver/graph/schema/modelservice.graphqls
+++ b/apiserver/graph/schema/modelservice.graphqls
@@ -38,6 +38,18 @@ type ModelService {
 
    apiType: String
 
+    """
+    模型服务的大语言模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    llmModels: [String!]
+
+    """
+    模型服务的Embedding模型列表
+    规则:如果不填或者为空,则按照模型的API类型获取默认的模型列表
+    """
+    embeddingModels: [String!]
+ """ 服务地址: 仅针对第三方模型服务 """ @@ -90,6 +102,19 @@ input CreateModelServiceInput { 模型服务终端输入 """ endpoint: EndpointInput! + + + """ + 模型服务的大语言模型列表 + 规则;如果不填或者为空,则按照模型的API类型获取默认的模型列表 + """ + llmModels: [String!] + + """ + 模型服务的Embedding模型列表 + 规则;如果不填或者为空,则按照模型的API类型获取默认的模型列表 + """ + embeddingModels: [String!] } input UpdateModelServiceInput { @@ -123,6 +148,18 @@ input UpdateModelServiceInput { 模型服务终端输入 """ endpoint: EndpointInput! + + """ + 模型服务的大语言模型列表 + 规则;如果不填或者为空,则按照模型的API类型获取默认的模型列表 + """ + llmModels: [String!] + + """ + 模型服务的Embedding模型列表 + 规则;如果不填或者为空,则按照模型的API类型获取默认的模型列表 + """ + embeddingModels: [String!] } input ListModelServiceInput { diff --git a/apiserver/pkg/modelservice/helper.go b/apiserver/pkg/modelservice/helper.go index 7d6b8b71f..0d44f13a3 100644 --- a/apiserver/pkg/modelservice/helper.go +++ b/apiserver/pkg/modelservice/helper.go @@ -42,6 +42,8 @@ func Embedder2ModelService(embedder *generated.Embedder) *generated.ModelService APIType: embedder.Type, // BaseURL of this Embedder BaseURL: embedder.BaseURL, + // EmbeddingModels of this modelservice + EmbeddingModels: embedder.Models, // Statuds of this model service Status: embedder.Status, Message: embedder.Message, @@ -68,8 +70,10 @@ func LLM2ModelService(llm *generated.Llm) *generated.ModelService { Types: &common.ModelTypeLLM, // APIType of this modelservice APIType: llm.Type, - // BaseURL of this Embedder + // BaseURL of this modelservice BaseURL: llm.BaseURL, + // EmbeddingModels of this modelservice + LlmModels: llm.Models, // Statuds of this model service Status: llm.Status, Message: llm.Message, diff --git a/apiserver/pkg/modelservice/modelservice.go b/apiserver/pkg/modelservice/modelservice.go index 900b67054..32d83f1a7 100644 --- a/apiserver/pkg/modelservice/modelservice.go +++ b/apiserver/pkg/modelservice/modelservice.go @@ -24,6 +24,7 @@ import ( "strings" "time" + "github.com/tmc/langchaingo/llms" "k8s.io/client-go/dynamic" "github.com/kubeagi/arcadia/apiserver/graph/generated" @@ -214,10 +215,13 @@ func ListModelServices(ctx context.Context, c dynamic.Interface, input *generate serviceMapList := make(map[string]*generated.ModelService) for _, node := range append(llmList.Nodes, embedderList.Nodes...) { ms, _ := node.(*generated.ModelService) - _, ok := serviceMapList[ms.Name] + curr, ok := serviceMapList[ms.Name] // if llm & embedder has same name,we treat it as `ModelTypeAll(llm,embedding)` if ok { ms.Types = &common.ModelTypeAll + // combine models provided by this model service + ms.LlmModels = append(ms.LlmModels, curr.LlmModels...) + ms.EmbeddingModels = append(ms.EmbeddingModels, curr.EmbeddingModels...) 
 		}
 		serviceMapList[ms.Name] = ms
 	}
@@ -334,8 +338,13 @@ func CheckModelService(ctx context.Context, c dynamic.Interface, input generated
 
 func checkOpenAI(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (string, error) {
 	apiKey := input.Endpoint.Auth["apiKey"].(string)
-	client := openai.NewOpenAI(apiKey, input.Endpoint.URL)
-	res, err := client.Validate()
+	client, err := openai.NewOpenAI(apiKey, input.Endpoint.URL)
+	if err != nil {
+		return "", err
+	}
+
+	// TODO: be able to validate specific openai models
+	res, err := client.Validate(ctx, llms.WithModel(""))
 	if err != nil {
 		return "", err
 	}
@@ -345,7 +354,7 @@ func checkOpenAI(ctx context.Context, c dynamic.Interface, input generated.Creat
 func checkZhipuAI(ctx context.Context, c dynamic.Interface, input generated.CreateModelServiceInput) (string, error) {
 	apiKey := input.Endpoint.Auth["apiKey"].(string)
 	client := zhipuai.NewZhiPuAI(apiKey)
-	res, err := client.Validate()
+	res, err := client.Validate(ctx)
 	if err != nil {
 		return "", err
 	}
diff --git a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_embedders.yaml b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_embedders.yaml
index e79c0bcbb..36ac8b9fc 100644
--- a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_embedders.yaml
+++ b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_embedders.yaml
@@ -54,6 +54,12 @@ spec:
             displayName:
               description: DisplayName defines datasource display name
               type: string
+            models:
+              description: Models provided by this Embedder If not set, a default
+                model list is used based on the embedding type
+              items:
+                type: string
+              type: array
             provider:
               description: Provider defines the provider info which provide this
                 embedder service
diff --git a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_llms.yaml b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_llms.yaml
index 687d6f233..37ba71b98 100644
--- a/config/crd/bases/arcadia.kubeagi.k8s.com.cn_llms.yaml
+++ b/config/crd/bases/arcadia.kubeagi.k8s.com.cn_llms.yaml
@@ -54,6 +54,12 @@ spec:
             displayName:
               description: DisplayName defines datasource display name
               type: string
+            models:
+              description: Models provided by this LLM If not set, a default model
+                list is used based on the LLMType
+              items:
+                type: string
+              type: array
             provider:
               description: Provider defines the provider info which provide this
                 llm service
diff --git a/config/samples/arcadia_v1alpha1_embedder_fs.yaml b/config/samples/arcadia_v1alpha1_embedder_fs.yaml
new file mode 100644
index 000000000..4fe04d0ac
--- /dev/null
+++ b/config/samples/arcadia_v1alpha1_embedder_fs.yaml
@@ -0,0 +1,25 @@
+apiVersion: v1
+kind: Secret
+metadata:
+  name: qwen-7b-chat-fs
+type: Opaque
+data:
+  apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key
+---
+apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
+kind: Embedder
+metadata:
+  name: qwen-7b-chat-fs
+spec:
+  displayName: 通义千问7B对话
+  description: "这是一个对话模型服务,由Arcadia的Worker提供"
+  type: "openai"
+  models:
+    - a3e0c8a6-101c-4000-a1cd-d523ff7f521d
+  provider:
+    endpoint:
+      url: "http://fastchat-api.172.22.96.167.nip.io/v1" # replace this with your model service URL (ZhiPuAI uses the predefined url https://open.bigmodel.cn/api/paas/v3/model-api)
+      authSecret:
+        kind: secret
+        name: qwen-7b-chat-fs
+      insecure: true
diff --git a/config/samples/arcadia_v1alpha1_llm_fs.yaml b/config/samples/arcadia_v1alpha1_llm_fs.yaml
new file mode 100644
index 000000000..621d88d03
--- /dev/null
+++ b/config/samples/arcadia_v1alpha1_llm_fs.yaml
@@ -0,0 +1,25 @@
+apiVersion: v1
+kind: Secret
+metadata:
+  name: qwen-7b-chat-fs
+type: Opaque
+data:
+  apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key
+---
+apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
+kind: LLM
+metadata:
+  name: qwen-7b-chat-fs
+spec:
+  displayName: 通义千问7B对话
+  description: "这是一个对话模型服务,由Arcadia的Worker提供"
+  type: "openai"
+  models:
+    - a3e0c8a6-101c-4000-a1cd-d523ff7f521d
+  provider:
+    endpoint:
+      url: "http://fastchat-api.172.22.96.167.nip.io/v1" # replace this with your model service URL (ZhiPuAI uses the predefined url https://open.bigmodel.cn/api/paas/v3/model-api)
+      authSecret:
+        kind: secret
+        name: qwen-7b-chat-fs
+      insecure: true
diff --git a/controllers/embedder_controller.go b/controllers/embedder_controller.go
index d488bad0f..d9dabbb4e 100644
--- a/controllers/embedder_controller.go
+++ b/controllers/embedder_controller.go
@@ -23,6 +23,7 @@ import (
 	"reflect"
 
 	"github.com/go-logr/logr"
+	langchainllms "github.com/tmc/langchaingo/llms"
 	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
@@ -184,17 +185,26 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l
 		return r.UpdateStatus(ctx, instance, nil, err)
 	}
 
+	models := instance.Get3rdPartyModels()
+	if len(models) == 0 {
+		return r.UpdateStatus(ctx, instance, nil, errors.New("no models provided by this embedder"))
+	}
+
 	switch instance.Spec.Type {
 	case embeddings.ZhiPuAI:
 		embedClient := zhipuai.NewZhiPuAI(apiKey)
-		res, err := embedClient.Validate()
+		res, err := embedClient.Validate(ctx)
 		if err != nil {
 			return r.UpdateStatus(ctx, instance, nil, err)
 		}
 		msg = res.String()
 	case embeddings.OpenAI:
-		embedClient := openai.NewOpenAI(apiKey, instance.Spec.Endpoint.URL)
-		res, err := embedClient.Validate()
+		embedClient, err := openai.NewOpenAI(apiKey, instance.Spec.Endpoint.URL)
+		if err != nil {
+			return r.UpdateStatus(ctx, instance, nil, err)
+		}
+		// validate against the first model
+		res, err := embedClient.Validate(ctx, langchainllms.WithModel(models[0]))
 		if err != nil {
 			return r.UpdateStatus(ctx, instance, nil, err)
 		}
diff --git a/controllers/llm_controller.go b/controllers/llm_controller.go
index 45fd534ab..89c966566 100644
--- a/controllers/llm_controller.go
+++ b/controllers/llm_controller.go
@@ -23,6 +23,7 @@ import (
 	"reflect"
 
 	"github.com/go-logr/logr"
+	langchainllms "github.com/tmc/langchaingo/llms"
 	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
@@ -170,17 +171,26 @@ func (r *LLMReconciler) check3rdPartyLLM(ctx context.Context, logger logr.Logger
 		return r.UpdateStatus(ctx, instance, nil, err)
 	}
 
+	models := instance.Get3rdPartyModels()
+	if len(models) == 0 {
+		return r.UpdateStatus(ctx, instance, nil, errors.New("no models provided by this llm"))
+	}
+
 	switch instance.Spec.Type {
 	case llms.ZhiPuAI:
-		embedClient := zhipuai.NewZhiPuAI(apiKey)
-		res, err := embedClient.Validate()
+		llmClient := zhipuai.NewZhiPuAI(apiKey)
+		res, err := llmClient.Validate(ctx)
 		if err != nil {
 			return r.UpdateStatus(ctx, instance, nil, err)
 		}
 		msg = res.String()
 	case llms.OpenAI:
-		embedClient := openai.NewOpenAI(apiKey, instance.Spec.Endpoint.URL)
-		res, err := embedClient.Validate()
+		llmClient, err := openai.NewOpenAI(apiKey, instance.Spec.Endpoint.URL)
+		if err != nil {
+			return r.UpdateStatus(ctx, instance, nil, err)
+		}
+		// validate against the first model
+		res, err := llmClient.Validate(ctx, langchainllms.WithModel(models[0]))
 		if err != nil {
 			return r.UpdateStatus(ctx, instance, nil, err)
 		}
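Both controllers now resolve the model list up front and probe the endpoint against its first entry. The sketch below condenses that shared flow; `ValidateFirstModel` and its package are hypothetical helpers, but the langchaingo calls are the same ones the controllers use above:

```go
package validation

import (
	"context"
	"errors"
	"fmt"

	langchainllms "github.com/tmc/langchaingo/llms"
	langchainopenai "github.com/tmc/langchaingo/llms/openai"
)

// ValidateFirstModel probes an OpenAI-compatible endpoint by running a
// trivial completion against the first configured model, mirroring the
// error handling in check3rdPartyLLM and check3rdPartyEmbedder.
func ValidateFirstModel(ctx context.Context, apiKey, baseURL string, models []string) (string, error) {
	if len(models) == 0 {
		return "", errors.New("no models provided")
	}
	llm, err := langchainopenai.New(
		langchainopenai.WithBaseURL(baseURL),
		langchainopenai.WithToken(apiKey),
	)
	if err != nil {
		return "", fmt.Errorf("init openai client: %w", err)
	}
	// A successful completion proves both the credentials and the model name.
	return llm.Call(ctx, "Hello", langchainllms.WithModel(models[0]))
}
```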
diff --git a/controllers/prompt_controller.go b/controllers/prompt_controller.go
index 1dc1313a6..6b7d6be8e 100644
--- a/controllers/prompt_controller.go
+++ b/controllers/prompt_controller.go
@@ -130,7 +130,10 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
 		llmClient = llmszhipuai.NewZhiPuAI(apiKey)
 		callData = prompt.Spec.ZhiPuAIParams.Marshal()
 	case llms.OpenAI:
-		llmClient = openai.NewOpenAI(apiKey, llm.Spec.Endpoint.URL)
+		llmClient, err = openai.NewOpenAI(apiKey, llm.Spec.Endpoint.URL)
+		if err != nil {
+			return r.UpdateStatus(ctx, prompt, nil, err)
+		}
 	default:
 		llmClient = llms.NewUnknowLLM()
 	}
diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml
index 4f48b7008..f2d971bef 100644
--- a/deploy/charts/arcadia/Chart.yaml
+++ b/deploy/charts/arcadia/Chart.yaml
@@ -2,7 +2,7 @@ apiVersion: v2
 name: arcadia
 description: A Helm chart(KubeBB Component) for KubeAGI Arcadia
 type: application
-version: 0.1.52
+version: 0.1.53
 appVersion: "0.0.1"
 
 keywords:
diff --git a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_embedders.yaml b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_embedders.yaml
index e79c0bcbb..36ac8b9fc 100644
--- a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_embedders.yaml
+++ b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_embedders.yaml
@@ -54,6 +54,12 @@ spec:
             displayName:
               description: DisplayName defines datasource display name
               type: string
+            models:
+              description: Models provided by this Embedder If not set, a default
+                model list is used based on the embedding type
+              items:
+                type: string
+              type: array
             provider:
               description: Provider defines the provider info which provide this
                 embedder service
diff --git a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_llms.yaml b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_llms.yaml
index 687d6f233..37ba71b98 100644
--- a/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_llms.yaml
+++ b/deploy/charts/arcadia/crds/arcadia.kubeagi.k8s.com.cn_llms.yaml
@@ -54,6 +54,12 @@ spec:
             displayName:
               description: DisplayName defines datasource display name
               type: string
+            models:
+              description: Models provided by this LLM If not set, a default model
+                list is used based on the LLMType
+              items:
+                type: string
+              type: array
             provider:
               description: Provider defines the provider info which provide this
                 llm service
diff --git a/examples/chat_with_document/load.go b/examples/chat_with_document/load.go
index 7d5915045..6edc4814d 100644
--- a/examples/chat_with_document/load.go
+++ b/examples/chat_with_document/load.go
@@ -19,11 +19,12 @@ package main
 import (
 	"context"
 	"fmt"
+	"os"
+
 	"github.com/spf13/cobra"
 	"github.com/tmc/langchaingo/documentloaders"
 	"github.com/tmc/langchaingo/textsplitter"
 	"github.com/tmc/langchaingo/vectorstores/chroma"
-	"os"
 
 	zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
 	"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
@@ -79,7 +80,7 @@ func runLoad(ctx context.Context) error {
 	fmt.Println("Connecting platform...")
 
 	z := zhipuai.NewZhiPuAI(apiKey)
-	_, err := z.Validate()
+	_, err := z.Validate(ctx)
 	if err != nil {
 		return fmt.Errorf("error validating ZhiPuAI api key: %s", err.Error())
 	}
diff --git a/examples/chat_with_document/start.go b/examples/chat_with_document/start.go
index 2d1fb50b0..f99d75496 100644
--- a/examples/chat_with_document/start.go
+++ b/examples/chat_with_document/start.go
@@ -17,6 +17,7 @@ limitations under the License.
 package main
 
 import (
+	"context"
 	"fmt"
 
 	"github.com/gofiber/fiber/v2"
@@ -76,7 +77,7 @@ func run() error {
 	fmt.Println("Connecting platform...")
 
 	z := zhipuai.NewZhiPuAI(apiKey)
-	_, err := z.Validate()
+	_, err := z.Validate(context.TODO())
 	if err != nil {
 		return fmt.Errorf("error validating ZhiPuAI api key: %s", err.Error())
 	}
diff --git a/pkg/llms/dashscope/api.go b/pkg/llms/dashscope/api.go
index adab44769..0d818a5f5 100644
--- a/pkg/llms/dashscope/api.go
+++ b/pkg/llms/dashscope/api.go
@@ -22,6 +22,8 @@ import (
 	"errors"
 	"fmt"
 
+	langchainllms "github.com/tmc/langchaingo/llms"
+
 	"github.com/kubeagi/arcadia/pkg/llms"
 )
@@ -73,7 +75,7 @@ func (z *DashScope) Call(data []byte) (llms.Response, error) {
 	return do(context.TODO(), DashScopeChatURL, z.apiKey, data, z.sse, false, params.Model)
 }
 
-func (z *DashScope) Validate() (llms.Response, error) {
+func (z *DashScope) Validate(ctx context.Context, options ...langchainllms.CallOption) (llms.Response, error) {
 	return nil, errors.New("not implemented")
 }
diff --git a/pkg/llms/llms.go b/pkg/llms/llms.go
index 6a028e3d4..55afbfd4e 100644
--- a/pkg/llms/llms.go
+++ b/pkg/llms/llms.go
@@ -16,7 +16,12 @@ limitations under the License.
 
 package llms
 
-import "errors"
+import (
+	"context"
+	"errors"
+
+	langchainllms "github.com/tmc/langchaingo/llms"
+)
 
 type LLMType string
 
@@ -40,7 +45,7 @@ var ZhiPuAIModels = []string{ZhiPuAILite, ZhiPuAIStd, ZhiPuAIPro, ZhiPuAITurbo}
 type LLM interface {
 	Type() LLMType
 	Call([]byte) (Response, error)
-	Validate() (Response, error)
+	Validate(context.Context, ...langchainllms.CallOption) (Response, error)
 }
 
 type ModelParams interface {
@@ -68,6 +73,6 @@ func (unknown UnknowLLM) Call(data []byte) (Response, error) {
 	return nil, errors.New("unknown llm type")
 }
 
-func (unknown UnknowLLM) Validate() (Response, error) {
+func (unknown UnknowLLM) Validate(ctx context.Context, options ...langchainllms.CallOption) (Response, error) {
 	return nil, errors.New("unknown llm type")
 }
diff --git a/pkg/llms/openai/api.go b/pkg/llms/openai/api.go
index 413c1b397..591624a23 100644
--- a/pkg/llms/openai/api.go
+++ b/pkg/llms/openai/api.go
@@ -17,11 +17,14 @@ limitations under the License.
 package openai
 
 import (
+	"context"
 	"errors"
 	"fmt"
-	"net/http"
 	"time"
 
+	langchainllms "github.com/tmc/langchaingo/llms"
+	langchainopenai "github.com/tmc/langchaingo/llms/openai"
+
 	"github.com/kubeagi/arcadia/pkg/llms"
 )
@@ -37,14 +40,20 @@ type OpenAI struct {
 	baseURL string
 }
 
-func NewOpenAI(apiKey string, baseURL string) *OpenAI {
+func NewOpenAI(apiKey string, baseURL string) (*OpenAI, error) {
 	if baseURL == "" {
 		baseURL = OpenaiModelAPIURL
 	}
+
+	if apiKey == "" {
+		// TODO: maybe we should consider local pseudo-openAI LLM worker that doesn't require an apiKey?
+		return nil, fmt.Errorf("auth is empty")
+	}
+
 	return &OpenAI{
 		apiKey:  apiKey,
 		baseURL: baseURL,
-	}
+	}, nil
 }
 
 func (o OpenAI) Type() llms.LLMType {
@@ -55,43 +64,26 @@ func (o *OpenAI) Call(data []byte) (llms.Response, error) {
 	return nil, errors.New("not implemented yet")
 }
 
-func (o *OpenAI) Validate() (llms.Response, error) {
-	// Validate OpenAI type CRD LLM Instance
-	// instance.Spec.URL should be like "https://api.openai.com/"
-
-	if o.apiKey == "" {
-		// TODO: maybe we should consider local pseudo-openAI LLM worker that doesn't require an apiKey?
- return nil, fmt.Errorf("auth is empty") - } - - testURL := o.baseURL + "/models" - testAuth := "Bearer " + o.apiKey // openAI official requirement - - req, err := http.NewRequest("GET", testURL, nil) +// Validate OpenAI service +func (o *OpenAI) Validate(ctx context.Context, options ...langchainllms.CallOption) (llms.Response, error) { + // validate agains models + llm, err := langchainopenai.New( + langchainopenai.WithBaseURL(o.baseURL), + langchainopenai.WithToken(o.apiKey), + ) if err != nil { - return nil, err + return nil, fmt.Errorf("init openai client: %w", err) } - req.Header.Set("Authorization", testAuth) - req.Header.Set("Content-Type", "application/json") - - cli := &http.Client{} - resp, err := cli.Do(req) + resp, err := llm.Call(ctx, "Hello", options...) if err != nil { return nil, err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("returns unexpected status code: %d", resp.StatusCode) - } - // FIXME: response object - response, err := parseHTTPResponse(resp) - if err != nil { - return nil, err - } - return response, nil + return &Response{ + Code: 200, + Data: resp, + Msg: "", + Success: true, + }, nil } - -// TODO: Openai Model Object & Other definition diff --git a/pkg/llms/openai/object.go b/pkg/llms/openai/object.go deleted file mode 100644 index 2b67c15bc..000000000 --- a/pkg/llms/openai/object.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package openai - -// Chat is a chat completion response returned by model. -type Chat struct { - ID string `json:"id"` // A unique identifier for the chat completion - Object string `json:"object"` // The object type, which is always chat.completion - Created int `json:"created"` // A unix timestamp of when the chat completion was created. - Model string `json:"model"` // The model used for the chat completion. - Choices []Choice `json:"choices"` // A list of chat completion choices. Can be more than one if n is greater than 1. - Usage Usage `json:"usage"` // Usage statistics of the completion request. -} - -// ChatStream is a streamed chunk of a chat completion returned by model. -type ChatStream struct { - ID string `json:"id"` // A unique identifier for the chat completion. - Object string `json:"object"` // The object type, which is always chat.completion - Created int `json:"created"` // A unix timestamp of when the chat completion was created. - Model string `json:"model"` // The model used for the chat completion. - Choices []ChoiceStream `json:"choices"` // A list of chat completion choices. Can be more than one if n is greater than 1. -} - -type Choice struct { - Index int `json:"index"` // The index of the choice in the list of choices. - Message Message `json:"message"` // The completion message generated by the model. - FinishReason string `json:"finish_reason"` // The reason the model stopped generating tokens. 
-}
-
-type ChoiceStream struct {
-	Index        int    `json:"index"`
-	Delta        Delta  `json:"delta"`
-	FinishReason string `json:"finish_reason"`
-}
-
-// Message is a chat completion message generated by the model.
-type Message struct {
-	Role         string       `json:"role"`
-	Content      string       `json:"content,omitempty"`
-	FunctionCall FunctionCall `json:"function_call,omitempty"`
-}
-
-// FunctionCall is used when a message is calling a function generated by openAI model.
-type FunctionCall struct {
-	Name      string `json:"name"`      // Name of the function.
-	Arguments string `json:"arguments"` // JSON format of the arguments.
-}
-
-// Usage is the usage statistics of the completion request.
-type Usage struct {
-	PromptTokens     int `json:"prompt_tokens"`
-	CompletionTokens int `json:"completion_tokens"`
-	TotalTokens      int `json:"total_tokens"`
-}
-
-// Delta is A chat completion delta generated by streamed model responses.
-type Delta struct {
-	Role         string       `json:"role"`
-	Content      string       `json:"content,omitempty"`
-	FunctionCall FunctionCall `json:"function_call,omitempty"`
-}
diff --git a/pkg/llms/openai/response.go b/pkg/llms/openai/response.go
index c840ea9f3..7b5a47764 100644
--- a/pkg/llms/openai/response.go
+++ b/pkg/llms/openai/response.go
@@ -18,15 +18,13 @@ package openai
 
 import (
 	"encoding/json"
-	"fmt"
-	"net/http"
 
 	"github.com/kubeagi/arcadia/pkg/llms"
 )
 
 type Response struct {
 	Code    int    `json:"code"`
-	Data    string `json:"data"` // JSON format of the returned data
+	Data    string `json:"data"`
 	Msg     string `json:"msg"`
 	Success bool   `json:"success"`
 }
@@ -50,17 +48,3 @@ func (response *Response) String() string {
 func (response *Response) Unmarshal(bytes []byte) error {
 	return json.Unmarshal(bytes, response)
 }
-
-func parseHTTPResponse(resp *http.Response) (*Response, error) {
-	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("exception: %s", resp.Status)
-	}
-
-	var data = new(Response)
-	err := json.NewDecoder(resp.Body).Decode(&data)
-	if err != nil {
-		return nil, err
-	}
-
-	return data, nil
-}
diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go
index b18e2469f..99a284b0d 100644
--- a/pkg/llms/zhipuai/api.go
+++ b/pkg/llms/zhipuai/api.go
@@ -25,6 +25,7 @@ import (
 	"time"
 
 	"github.com/r3labs/sse/v2"
+	langchainllms "github.com/tmc/langchaingo/llms"
 	"k8s.io/klog/v2"
 
 	"github.com/kubeagi/arcadia/pkg/llms"
@@ -132,7 +133,8 @@ func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error
 	return Stream(url, token, params, ZhipuaiModelDefaultTimeout, handler)
 }
 
-func (z *ZhiPuAI) Validate() (llms.Response, error) {
+// Validate the zhipuai service against the given CallOptions
+func (z *ZhiPuAI) Validate(ctx context.Context, options ...langchainllms.CallOption) (llms.Response, error) {
 	url := BuildAPIURL(llms.ZhiPuAILite, ZhiPuAIInvoke)
 	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
 	if err != nil {
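For reference, a usage sketch of the widened `Validate` signature introduced by this change; the API key, base URL, and model name below are placeholders:

```go
package main

import (
	"context"
	"fmt"
	"log"

	langchainllms "github.com/tmc/langchaingo/llms"

	"github.com/kubeagi/arcadia/pkg/llms/openai"
)

func main() {
	// NewOpenAI now returns an error when the API key is empty.
	client, err := openai.NewOpenAI("sk-placeholder", "https://api.openai.com/v1")
	if err != nil {
		log.Fatal(err)
	}
	// Callers pass a context plus optional langchaingo call options,
	// such as a model override.
	resp, err := client.Validate(context.Background(), langchainllms.WithModel("gpt-3.5-turbo"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.String())
}
```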