Skip to content

Commit

Permalink
fix: optimize llm checks
Browse files Browse the repository at this point in the history
Signed-off-by: bjwswang <[email protected]>
  • Loading branch information
bjwswang committed Aug 30, 2023
1 parent c975869 commit d5a2bfb
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 15 deletions.
47 changes: 47 additions & 0 deletions api/v1alpha1/llm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package v1alpha1

import (
"context"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)

func (llm LLM) AuthAPIKey(ctx context.Context, c client.Client) (string, error) {
if llm.Spec.Auth == "" {
return "", nil
}
authSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Name: llm.Spec.Auth, Namespace: llm.Namespace}, authSecret)
if err != nil {
return "", err
}
return string(authSecret.Data["apiKey"]), nil
}

func (llmStatus LLMStatus) LLMReady() (string, bool) {
if len(llmStatus.Conditions) == 0 {
return "No conditions yet", false
}
if llmStatus.Conditions[0].Type != TypeReady || llmStatus.Conditions[0].Status != corev1.ConditionTrue {
return "Bad condition", false
}
return "", true
}
2 changes: 1 addition & 1 deletion config/samples/arcadia_v1alpha1_llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ metadata:
name: zhipuai
type: Opaque
data:
apiKey: "MjZiMmJjNTVmYWU0MDc1MjA1NWNhZGZjNDc5MmY5ZGUud2FnQTROSXdnNWFaSldobQ==" # replace this with your API key
apiKey: "ZmVkMDk2NWJjZTAxOTBmZjJiYzY4MWFjMzA2ZDVmM2QuZUlwN3NPWHJueG1XSnhPaw==" # replace this with your API key
---
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: LLM
Expand Down
6 changes: 2 additions & 4 deletions controllers/llm_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package controllers
import (
"context"
"fmt"

"github.com/go-logr/logr"
"github.com/kubeagi/arcadia/pkg/llms"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -94,12 +94,10 @@ func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instan
var err error
var response llms.Response

secret := &corev1.Secret{}
err = r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret)
apiKey, err := instance.AuthAPIKey(ctx, r.Client)
if err != nil {
return err
}
apiKey := string(secret.Data["apiKey"])

switch instance.Spec.Type {
case llms.OpenAI:
Expand Down
11 changes: 3 additions & 8 deletions controllers/prompt_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,13 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
return err
}

var apiKey string
if llm.Spec.Auth != "" {
authSecret := corev1.Secret{}
if err := r.Get(ctx, types.NamespacedName{Name: llm.Spec.Auth, Namespace: prompt.Namespace}, &authSecret); err != nil {
return err
}
apiKey = string(authSecret.Data["apiKey"])
apiKey, err := llm.AuthAPIKey(ctx, r.Client)
if err != nil {
return err
}

// llm call
var resp llms.Response
var err error
switch llm.Spec.Type {
case llms.ZhiPuAI:
resp, err = llmszhipuai.NewZhiPuAI(apiKey).Call(*prompt.Spec.ZhiPuAIParams)
Expand Down
5 changes: 5 additions & 0 deletions pkg/llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ const (
ZhiPuAI LLMType = "zhipuai"
)

type LLM interface {
Type() LLMType
Validate() (Response, error)
}

type Response interface {
Type() LLMType
String() string
Expand Down
10 changes: 9 additions & 1 deletion pkg/llms/openai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ import (
"fmt"
"net/http"
"time"

"github.com/kubeagi/arcadia/pkg/llms"
)

const (
OpenaiModelAPIURL = "https://api.openai.com/v1"
OpenaiDefaultTimeout = 300 * time.Second
)

var _ llms.LLM = (*OpenAI)(nil)

type OpenAI struct {
apiKey string
}
Expand All @@ -37,7 +41,11 @@ func NewOpenAI(auth string) *OpenAI {
}
}

func (o *OpenAI) Validate() (*Response, error) {
func (o OpenAI) Type() llms.LLMType {
return llms.OpenAI
}

func (o *OpenAI) Validate() (llms.Response, error) {
// Validate OpenAI type CRD LLM Instance
// instance.Spec.URL should be like "https://api.openai.com/"

Expand Down
9 changes: 8 additions & 1 deletion pkg/llms/zhipuai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"time"

"github.com/kubeagi/arcadia/pkg/llms"
"github.com/r3labs/sse/v2"
)

Expand Down Expand Up @@ -54,6 +55,8 @@ func BuildAPIURL(model Model, method Method) string {
return fmt.Sprintf("%s/%s/%s", ZhipuaiModelAPIURL, model, method)
}

var _ llms.LLM = (*ZhiPuAI)(nil)

type ZhiPuAI struct {
apiKey string
}
Expand All @@ -64,6 +67,10 @@ func NewZhiPuAI(apiKey string) *ZhiPuAI {
}
}

func (z ZhiPuAI) Type() llms.LLMType {
return llms.ZhiPuAI
}

// Call wraps a common AI api call
func (z *ZhiPuAI) Call(params ModelParams) (*Response, error) {
switch params.Method {
Expand Down Expand Up @@ -125,7 +132,7 @@ func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error
return Stream(url, token, params, ZhipuaiModelDefaultTimeout, nil)
}

func (z *ZhiPuAI) Validate() (*Response, error) {
func (z *ZhiPuAI) Validate() (llms.Response, error) {
url := BuildAPIURL(ZhiPuAILite, ZhiPuAIInvoke)
token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
if err != nil {
Expand Down

0 comments on commit d5a2bfb

Please sign in to comment.