diff --git a/api/v1alpha1/llm.go b/api/v1alpha1/llm.go new file mode 100644 index 000000000..03c745add --- /dev/null +++ b/api/v1alpha1/llm.go @@ -0,0 +1,47 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package v1alpha1 + +import ( + "context" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func (llm LLM) AuthAPIKey(ctx context.Context, c client.Client) (string, error) { + if llm.Spec.Auth == "" { + return "", nil + } + authSecret := &corev1.Secret{} + err := c.Get(ctx, types.NamespacedName{Name: llm.Spec.Auth, Namespace: llm.Namespace}, authSecret) + if err != nil { + return "", err + } + return string(authSecret.Data["apiKey"]), nil +} + +func (llmStatus LLMStatus) LLMReady() (string, bool) { + if len(llmStatus.Conditions) == 0 { + return "No conditions yet", false + } + if llmStatus.Conditions[0].Type != TypeReady || llmStatus.Conditions[0].Status != corev1.ConditionTrue { + return "Bad condition", false + } + return "", true +} diff --git a/config/samples/arcadia_v1alpha1_llm.yaml b/config/samples/arcadia_v1alpha1_llm.yaml index b8ec0863d..53aa41dbb 100644 --- a/config/samples/arcadia_v1alpha1_llm.yaml +++ b/config/samples/arcadia_v1alpha1_llm.yaml @@ -4,7 +4,7 @@ metadata: name: zhipuai type: Opaque data: - apiKey: "MjZiMmJjNTVmYWU0MDc1MjA1NWNhZGZjNDc5MmY5ZGUud2FnQTROSXdnNWFaSldobQ==" # replace this with your API key + apiKey: "ZmVkMDk2NWJjZTAxOTBmZjJiYzY4MWFjMzA2ZDVmM2QuZUlwN3NPWHJueG1XSnhPaw==" # replace this with your API key --- apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 kind: LLM diff --git a/controllers/llm_controller.go b/controllers/llm_controller.go index 26aac593a..5b8647d28 100644 --- a/controllers/llm_controller.go +++ b/controllers/llm_controller.go @@ -19,6 +19,7 @@ package controllers import ( "context" "fmt" + "github.com/go-logr/logr" "github.com/kubeagi/arcadia/pkg/llms" "github.com/kubeagi/arcadia/pkg/llms/zhipuai" @@ -26,7 +27,6 @@ import ( "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" @@ -94,12 +94,10 @@ func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instan var err error var response llms.Response - secret := &corev1.Secret{} - err = r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret) + apiKey, err := instance.AuthAPIKey(ctx, r.Client) if err != nil { return err } - apiKey := string(secret.Data["apiKey"]) switch instance.Spec.Type { case llms.OpenAI: diff --git a/controllers/prompt_controller.go b/controllers/prompt_controller.go index c3dad58d3..2d614a7b0 100644 --- a/controllers/prompt_controller.go +++ b/controllers/prompt_controller.go @@ -88,18 +88,13 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom return err } - var apiKey string - if llm.Spec.Auth != "" { - authSecret := corev1.Secret{} - if err := r.Get(ctx, types.NamespacedName{Name: llm.Spec.Auth, Namespace: prompt.Namespace}, &authSecret); err != nil { - return err - } - apiKey = string(authSecret.Data["apiKey"]) + apiKey, err := llm.AuthAPIKey(ctx, r.Client) + if err != nil { + return err } // llm call var resp llms.Response - var err error switch llm.Spec.Type { case llms.ZhiPuAI: resp, err = llmszhipuai.NewZhiPuAI(apiKey).Call(*prompt.Spec.ZhiPuAIParams) diff --git a/pkg/llms/llms.go b/pkg/llms/llms.go index cc3f891c2..b42959857 100644 --- a/pkg/llms/llms.go +++ b/pkg/llms/llms.go @@ -23,6 +23,11 @@ const ( ZhiPuAI LLMType = "zhipuai" ) +type LLM interface { + Type() LLMType + Validate() (Response, error) +} + type Response interface { Type() LLMType String() string diff --git a/pkg/llms/openai/api.go b/pkg/llms/openai/api.go index febde3ee7..4775f4059 100644 --- a/pkg/llms/openai/api.go +++ b/pkg/llms/openai/api.go @@ -20,6 +20,8 @@ import ( "fmt" "net/http" "time" + + "github.com/kubeagi/arcadia/pkg/llms" ) const ( @@ -27,6 +29,8 @@ const ( OpenaiDefaultTimeout = 300 * time.Second ) +var _ llms.LLM = (*OpenAI)(nil) + type OpenAI struct { apiKey string } @@ -37,7 +41,11 @@ func NewOpenAI(auth string) *OpenAI { } } -func (o *OpenAI) Validate() (*Response, error) { +func (o OpenAI) Type() llms.LLMType { + return llms.OpenAI +} + +func (o *OpenAI) Validate() (llms.Response, error) { // Validate OpenAI type CRD LLM Instance // instance.Spec.URL should be like "https://api.openai.com/" diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go index 3a1052ac0..d5d91f07b 100644 --- a/pkg/llms/zhipuai/api.go +++ b/pkg/llms/zhipuai/api.go @@ -23,6 +23,7 @@ import ( "fmt" "time" + "github.com/kubeagi/arcadia/pkg/llms" "github.com/r3labs/sse/v2" ) @@ -54,6 +55,8 @@ func BuildAPIURL(model Model, method Method) string { return fmt.Sprintf("%s/%s/%s", ZhipuaiModelAPIURL, model, method) } +var _ llms.LLM = (*ZhiPuAI)(nil) + type ZhiPuAI struct { apiKey string } @@ -64,6 +67,10 @@ func NewZhiPuAI(apiKey string) *ZhiPuAI { } } +func (z ZhiPuAI) Type() llms.LLMType { + return llms.ZhiPuAI +} + // Call wraps a common AI api call func (z *ZhiPuAI) Call(params ModelParams) (*Response, error) { switch params.Method { @@ -125,7 +132,7 @@ func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error return Stream(url, token, params, ZhipuaiModelDefaultTimeout, nil) } -func (z *ZhiPuAI) Validate() (*Response, error) { +func (z *ZhiPuAI) Validate() (llms.Response, error) { url := BuildAPIURL(ZhiPuAILite, ZhiPuAIInvoke) token, err := GenerateToken(z.apiKey, APITokenTTLSeconds) if err != nil {