Skip to content

Commit

Permalink
refactor: Extract LLM validate logic
Browse files Browse the repository at this point in the history
Signed-off-by: Lanture1064 <[email protected]>
  • Loading branch information
Lanture1064 committed Aug 24, 2023
1 parent 90448ef commit 6d2292a
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 57 deletions.
87 changes: 30 additions & 57 deletions controllers/llm_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@ package controllers
import (
"context"
"fmt"
"net/http"

"github.com/go-logr/logr"
"github.com/kubeagi/arcadia/pkg/llms"

Check failure on line 23 in controllers/llm_controller.go

View workflow job for this annotation

GitHub Actions / Lint Go code

File is not `gci`-ed with --skip-generated -s standard -s default -s prefix(github.com/kubeagi/arcadia) -s blank -s dot --custom-order (gci)
"github.com/kubeagi/arcadia/pkg/llms/openai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/builder"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/v1alpha1"
Expand Down Expand Up @@ -81,15 +84,34 @@ func (r *LLMReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
// SetupWithManager sets up the controller with the Manager.
func (r *LLMReconciler) SetupWithManager(mgr ctrl.Manager) error {
	// Watch LLM objects, filtered through LLMPredicates so only relevant
	// events trigger a reconcile.
	return ctrl.NewControllerManagedBy(mgr).
		For(&arcadiav1alpha1.LLM{}, builder.WithPredicates(LLMPredicates{})).
		Complete(r)
}

// CheckLLM updates new LLM instance.
func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instance *arcadiav1alpha1.LLM) error {
logger.Info("Checking LLM instance")
// Check new URL/Auth availability
err := r.TestLLMAvailability(ctx, instance, logger)
var err error
var validator llms.Validator

secret := &corev1.Secret{}
err = r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret)
if err != nil {
return err
}
apiKey := string(secret.Data["apiKey"])

switch instance.Spec.Type {
case llms.OpenAI:
validator = openai.NewOpenAI(apiKey)
case llms.ZhiPuAI:
validator = zhipuai.NewZhiPuAI(apiKey)
default:
return fmt.Errorf("unknown LLM type: %s", instance.Spec.Type)
}

res, err := validator.Validate()
if err != nil {
// Set status to unavailable
instance.Status.SetConditions(arcadiav1alpha1.Condition{
Expand All @@ -105,64 +127,15 @@ func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instan
Type: arcadiav1alpha1.TypeReady,
Status: corev1.ConditionTrue,
Reason: arcadiav1alpha1.ReasonAvailable,
Message: "Available",
Message: res,
LastTransitionTime: metav1.Now(),
LastSuccessfulTime: metav1.Now(),
})
}
return r.Client.Status().Update(ctx, instance)
}

// TestLLMAvailability tests LLM availability.
func (r *LLMReconciler) TestLLMAvailability(ctx context.Context, instance *arcadiav1alpha1.LLM, logger logr.Logger) error {
logger.Info("Testing LLM availability")

//TODO: change URL & request for different types of LLM instance
// For openai instance, we use the "GET model" api.
// For Zhipuai instance, we send a standard async request.
testURL := instance.Spec.URL + "/v1/models"

if instance.Spec.Auth == "" {
return fmt.Errorf("auth is empty")
}

// get auth by secret name
var auth string
secret := &corev1.Secret{}
err := r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret)
if err != nil {
return err
}

auth = "Bearer " + string(secret.Data["apiKey"])

err = SendTestRequest("GET", testURL, auth)
if err != nil {
return err
}

return nil
return r.Client.Status().Update(ctx, instance)
}

func SendTestRequest(method string, url string, auth string) error {
req, err := http.NewRequest(method, url, nil)
if err != nil {
return err
}

req.Header.Set("Authorization", auth)
req.Header.Set("Content-Type", "application/json")

cli := &http.Client{}
resp, err := cli.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("returns unexpected status code: %d", resp.StatusCode)
}

return nil
// LLMPredicates selects which LLM events the controller reconciles on.
// It embeds predicate.Funcs, so any hook not overridden falls back to the
// controller-runtime default (allow-all) behavior.
type LLMPredicates struct {
	predicate.Funcs
}
4 changes: 4 additions & 0 deletions pkg/llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ type Response interface {
String() string
Bytes() []byte
}

// Validator checks that an LLM backend is reachable and that the configured
// credentials are accepted. On success it returns a short human-readable
// status message; on failure it returns a non-nil error.
type Validator interface {
	Validate() (string, error)
}
72 changes: 72 additions & 0 deletions pkg/llms/openai/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
Copyright 2023 KubeAGI.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package openai

import (
"fmt"
"net/http"
"time"
)

const (
OPENAI_MODEL_API_URL = "https://api.openai.com/v1"

Check failure on line 26 in pkg/llms/openai/api.go

View workflow job for this annotation

GitHub Actions / Lint Go code

ST1003: should not use ALL_CAPS in Go names; use CamelCase instead (stylecheck)
OPENAI_MODEL_Default_Timeout = 300 * time.Second

Check failure on line 27 in pkg/llms/openai/api.go

View workflow job for this annotation

GitHub Actions / Lint Go code

ST1003: should not use underscores in Go names; const OPENAI_MODEL_Default_Timeout should be OPENAIMODELDefaultTimeout (stylecheck)
)

// OpenAI validates credentials against the OpenAI HTTP API.
type OpenAI struct {
	apiKey string // sent as "Authorization: Bearer <key>" on requests
}

// NewOpenAI returns an OpenAI client that authenticates with auth.
func NewOpenAI(auth string) *OpenAI {
	client := OpenAI{apiKey: auth}
	return &client
}

// Validate checks that the configured OpenAI API key is accepted by the
// OpenAI service by issuing an authenticated "GET /v1/models" request.
// Any response other than 200 OK (or a transport error) is reported as a
// validation failure. On success it returns a short status message
// suitable for surfacing in a CRD condition.
func (o *OpenAI) Validate() (string, error) {
	// Validate OpenAI type CRD LLM Instance
	// instance.Spec.URL should be like "https://api.openai.com/"

	if o.apiKey == "" {
		// TODO: maybe we should consider local pseudo-openAI LLM worker that doesn't require an apiKey?
		return "", fmt.Errorf("auth is empty")
	}

	// Bound the request so a stalled endpoint cannot hang reconciliation
	// indefinitely. Mirrors OPENAI_MODEL_Default_Timeout (300s); declared
	// locally so this method stays self-contained.
	const validateTimeout = 300 * time.Second

	testURL := OPENAI_MODEL_API_URL + "/models"

	req, err := http.NewRequest(http.MethodGet, testURL, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("Authorization", "Bearer "+o.apiKey) // openAI official requirement
	req.Header.Set("Content-Type", "application/json")

	cli := &http.Client{Timeout: validateTimeout}
	resp, err := cli.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("returns unexpected status code: %d", resp.StatusCode)
	}

	return "GET model request returns Status OK", nil
}
31 changes: 31 additions & 0 deletions pkg/llms/zhipuai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,34 @@ func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error
}
return Stream(url, token, params, ZHIPUAI_MODEL_Default_Timeout, nil)
}

// Validate confirms the stored ZhiPuAI credentials work by submitting a
// minimal async-invoke request against the lite model. On success it
// returns the task ID reported by the service.
func (z *ZhiPuAI) Validate() (string, error) {
	endpoint := BuildAPIURL(ZhiPuAILite, ZhiPuAIAsyncInvoke)

	token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
	if err != nil {
		return "", err
	}

	// A single trivial prompt is enough to exercise auth and routing.
	params := ModelParams{
		Method:      ZhiPuAIAsyncInvoke,
		Model:       ZhiPuAILite,
		Temperature: 0.95,
		TopP:        0.7,
		Prompt: []Prompt{
			{
				Role:    "user",
				Content: "Hello!",
			},
		},
	}

	resp, err := Post(endpoint, token, params, ZHIPUAI_MODEL_Default_Timeout)
	if err != nil {
		return "", err
	}

	return "TaskID:" + resp.Data.TaskID, nil
}

0 comments on commit 6d2292a

Please sign in to comment.