From 59ae4fb2e652a2e784eeb42041bc1d5e8692a105 Mon Sep 17 00:00:00 2001
From: Lanture1064
Date: Wed, 23 Aug 2023 16:23:07 +0800
Subject: [PATCH] refactor: Extract LLM validate logic

Signed-off-by: Lanture1064
---
 controllers/llm_controller.go | 88 ++++++++++++-----------------------
 pkg/llms/openai/api.go        | 79 +++++++++++++++++++++++++++++++
 pkg/llms/openai/object.go     | 75 +++++++++++++++++++++++++++++
 pkg/llms/openai/response.go   | 61 ++++++++++++++++++++++++
 pkg/llms/zhipuai/api.go       | 59 ++++++++++++++++++-----
 pkg/llms/zhipuai/jwt_token.go |  4 +-
 6 files changed, 296 insertions(+), 70 deletions(-)
 create mode 100644 pkg/llms/openai/api.go
 create mode 100644 pkg/llms/openai/object.go
 create mode 100644 pkg/llms/openai/response.go

diff --git a/controllers/llm_controller.go b/controllers/llm_controller.go
index 3b71f2817..26aac593a 100644
--- a/controllers/llm_controller.go
+++ b/controllers/llm_controller.go
@@ -19,17 +19,19 @@ package controllers
 import (
 	"context"
 	"fmt"
-	"net/http"
-
 	"github.com/go-logr/logr"
+	"github.com/kubeagi/arcadia/pkg/llms"
+	"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/apimachinery/pkg/types"
 	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/builder"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/controller-runtime/pkg/predicate"
 	"sigs.k8s.io/controller-runtime/pkg/reconcile"
 
 	arcadiav1alpha1 "github.com/kubeagi/arcadia/api/v1alpha1"
@@ -81,7 +83,7 @@ func (r *LLMReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
 
 // SetupWithManager sets up the controller with the Manager.
 func (r *LLMReconciler) SetupWithManager(mgr ctrl.Manager) error {
 	return ctrl.NewControllerManagedBy(mgr).
-		For(&arcadiav1alpha1.LLM{}).
+		For(&arcadiav1alpha1.LLM{}, builder.WithPredicates(LLMPredicates{})).
 		Complete(r)
 }
@@ -89,7 +91,28 @@ func (r *LLMReconciler) SetupWithManager(mgr ctrl.Manager) error {
 func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instance *arcadiav1alpha1.LLM) error {
 	logger.Info("Checking LLM instance")
 	// Check new URL/Auth availability
-	err := r.TestLLMAvailability(ctx, instance, logger)
+	var err error
+	var response llms.Response
+
+	secret := &corev1.Secret{}
+	err = r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret)
+	if err != nil {
+		return err
+	}
+	apiKey := string(secret.Data["apiKey"])
+
+	switch instance.Spec.Type {
+	case llms.OpenAI:
+		// validator := openai.NewOpenAI(apiKey)
+		// response, err = validator.Validate()
+		return fmt.Errorf("openAI not implemented yet")
+	case llms.ZhiPuAI:
+		validator := zhipuai.NewZhiPuAI(apiKey)
+		response, err = validator.Validate()
+	default:
+		return fmt.Errorf("unknown LLM type: %s", instance.Spec.Type)
+	}
+
 	if err != nil {
 		// Set status to unavailable
 		instance.Status.SetConditions(arcadiav1alpha1.Condition{
@@ -105,64 +128,15 @@ func (r *LLMReconciler) CheckLLM(ctx context.Context, logger logr.Logger, instan
 			Type:               arcadiav1alpha1.TypeReady,
 			Status:             corev1.ConditionTrue,
 			Reason:             arcadiav1alpha1.ReasonAvailable,
-			Message:            "Available",
+			Message:            response.String(),
 			LastTransitionTime: metav1.Now(),
 			LastSuccessfulTime: metav1.Now(),
 		})
 	}
-	return r.Client.Status().Update(ctx, instance)
-}
-
-// TestLLMAvailability tests LLM availability.
-func (r *LLMReconciler) TestLLMAvailability(ctx context.Context, instance *arcadiav1alpha1.LLM, logger logr.Logger) error {
-	logger.Info("Testing LLM availability")
-
-	//TODO: change URL & request for different types of LLM instance
-	// For openai instance, we use the "GET model" api.
-	// For Zhipuai instance, we send a standard async request.
-	testURL := instance.Spec.URL + "/v1/models"
-
-	if instance.Spec.Auth == "" {
-		return fmt.Errorf("auth is empty")
-	}
-
-	// get auth by secret name
-	var auth string
-	secret := &corev1.Secret{}
-	err := r.Get(ctx, types.NamespacedName{Name: instance.Spec.Auth, Namespace: instance.Namespace}, secret)
-	if err != nil {
-		return err
-	}
-
-	auth = "Bearer " + string(secret.Data["apiKey"])
-	err = SendTestRequest("GET", testURL, auth)
-	if err != nil {
-		return err
-	}
-
-	return nil
+	return r.Client.Status().Update(ctx, instance)
 }
 
-func SendTestRequest(method string, url string, auth string) error {
-	req, err := http.NewRequest(method, url, nil)
-	if err != nil {
-		return err
-	}
-
-	req.Header.Set("Authorization", auth)
-	req.Header.Set("Content-Type", "application/json")
-
-	cli := &http.Client{}
-	resp, err := cli.Do(req)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return fmt.Errorf("returns unexpected status code: %d", resp.StatusCode)
-	}
-
-	return nil
+type LLMPredicates struct {
+	predicate.Funcs
 }
 
diff --git a/pkg/llms/openai/api.go b/pkg/llms/openai/api.go
new file mode 100644
index 000000000..febde3ee7
--- /dev/null
+++ b/pkg/llms/openai/api.go
@@ -0,0 +1,79 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package openai
+
+import (
+	"fmt"
+	"net/http"
+	"time"
+)
+
+const (
+	OpenaiModelAPIURL    = "https://api.openai.com/v1"
+	OpenaiDefaultTimeout = 300 * time.Second
+)
+
+type OpenAI struct {
+	apiKey string
+}
+
+func NewOpenAI(auth string) *OpenAI {
+	return &OpenAI{
+		apiKey: auth,
+	}
+}
+
+func (o *OpenAI) Validate() (*Response, error) {
+	// Validate checks an OpenAI-type CRD LLM instance.
+	// instance.Spec.URL should be like "https://api.openai.com/"
+
+	if o.apiKey == "" {
+		// TODO: maybe we should consider a local pseudo-OpenAI LLM worker that doesn't require an apiKey?
+		return nil, fmt.Errorf("auth is empty")
+	}
+
+	testURL := OpenaiModelAPIURL + "/models"
+	testAuth := "Bearer " + o.apiKey // Bearer token auth, as the official OpenAI API requires
+
+	req, err := http.NewRequest("GET", testURL, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", testAuth)
+	req.Header.Set("Content-Type", "application/json")
+
+	cli := &http.Client{}
+	resp, err := cli.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+	}
+
+	// FIXME: response object (GET /v1/models returns an OpenAI list object, not this Response shape)
+	response, err := parseHTTPResponse(resp)
+	if err != nil {
+		return nil, err
+	}
+	return response, nil
+}
+
+// TODO: OpenAI model object & other definitions
diff --git a/pkg/llms/openai/object.go b/pkg/llms/openai/object.go
new file mode 100644
index 000000000..2b67c15bc
--- /dev/null
+++ b/pkg/llms/openai/object.go
@@ -0,0 +1,75 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package openai
+
+// Chat is a chat completion response returned by the model.
+type Chat struct {
+	ID      string   `json:"id"`      // A unique identifier for the chat completion.
+	Object  string   `json:"object"`  // The object type, which is always chat.completion.
+	Created int      `json:"created"` // The Unix timestamp (in seconds) of when the chat completion was created.
+	Model   string   `json:"model"`   // The model used for the chat completion.
+	Choices []Choice `json:"choices"` // A list of chat completion choices. Can be more than one if n is greater than 1.
+	Usage   Usage    `json:"usage"`   // Usage statistics of the completion request.
+}
+
+// ChatStream is a streamed chunk of a chat completion returned by the model.
+type ChatStream struct {
+	ID      string         `json:"id"`      // A unique identifier for the chat completion.
+	Object  string         `json:"object"`  // The object type, which is always chat.completion.chunk.
+	Created int            `json:"created"` // The Unix timestamp (in seconds) of when the chat completion was created.
+	Model   string         `json:"model"`   // The model used for the chat completion.
+	Choices []ChoiceStream `json:"choices"` // A list of chat completion choices. Can be more than one if n is greater than 1.
+}
+
+type Choice struct {
+	Index        int     `json:"index"`         // The index of the choice in the list of choices.
+	Message      Message `json:"message"`       // The completion message generated by the model.
+	FinishReason string  `json:"finish_reason"` // The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, or `function_call` if the model called a function.
+}
+
+type ChoiceStream struct {
+	Index        int    `json:"index"`
+	Delta        Delta  `json:"delta"`
+	FinishReason string `json:"finish_reason"`
+}
+
+// Message is a chat completion message generated by the model.
+type Message struct {
+	Role         string       `json:"role"`
+	Content      string       `json:"content,omitempty"`
+	FunctionCall FunctionCall `json:"function_call,omitempty"`
+}
+
+// FunctionCall describes a function call generated by the OpenAI model within a message.
+type FunctionCall struct {
+	Name      string `json:"name"`      // Name of the function.
+	Arguments string `json:"arguments"` // JSON format of the arguments.
+}
+
+// Usage is the usage statistics of the completion request.
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens     int `json:"total_tokens"`
+}
+
+// Delta is a chat completion delta generated by streamed model responses.
+type Delta struct {
+	Role         string       `json:"role"`
+	Content      string       `json:"content,omitempty"`
+	FunctionCall FunctionCall `json:"function_call,omitempty"`
+}
diff --git a/pkg/llms/openai/response.go b/pkg/llms/openai/response.go
new file mode 100644
index 000000000..87e33623b
--- /dev/null
+++ b/pkg/llms/openai/response.go
@@ -0,0 +1,61 @@
+/*
+Copyright 2023 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package openai
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+
+	"github.com/kubeagi/arcadia/pkg/llms"
+)
+
+type Response struct {
+	Code    int    `json:"code"`
+	Data    string `json:"data"` // JSON format of the returned data
+	Msg     string `json:"msg"`
+	Success bool   `json:"success"`
+}
+
+func (response *Response) Type() llms.LLMType {
+	return llms.OpenAI
+}
+
+func (response *Response) Bytes() []byte {
+	bytes, err := json.Marshal(response)
+	if err != nil {
+		return []byte{}
+	}
+	return bytes
+}
+
+func (response *Response) String() string {
+	return string(response.Bytes())
+}
+
+func parseHTTPResponse(resp *http.Response) (*Response, error) {
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("exception: %s", resp.Status)
+	}
+
+	var data = new(Response)
+	err := json.NewDecoder(resp.Body).Decode(data)
+	if err != nil {
+		return nil, err
+	}
+
+	return data, nil
+}
diff --git a/pkg/llms/zhipuai/api.go b/pkg/llms/zhipuai/api.go
index 307c541cd..894f0716c 100644
--- a/pkg/llms/zhipuai/api.go
+++ b/pkg/llms/zhipuai/api.go
@@ -27,8 +27,8 @@ import (
 )
 
 const (
-	ZHIPUAI_MODEL_API_URL         = "https://open.bigmodel.cn/api/paas/v3/model-api"
-	ZHIPUAI_MODEL_Default_Timeout = 300 * time.Second
+	ZhipuaiModelAPIURL         = "https://open.bigmodel.cn/api/paas/v3/model-api"
+	ZhipuaiModelDefaultTimeout = 300 * time.Second
 )
 
 type Model string
@@ -51,7 +51,7 @@ const (
 )
 
 func BuildAPIURL(model Model, method Method) string {
-	return fmt.Sprintf("%s/%s/%s", ZHIPUAI_MODEL_API_URL, model, method)
+	return fmt.Sprintf("%s/%s/%s", ZhipuaiModelAPIURL, model, method)
 }
 
 type ZhiPuAI struct {
@@ -81,23 +81,23 @@ func (z *ZhiPuAI) Call(params ModelParams) (*Response, error) {
 // Invoke calls zhipuai and returns result immediately
 func (z *ZhiPuAI) Invoke(params ModelParams) (*Response, error) {
 	url := BuildAPIURL(params.Model, ZhiPuAIInvoke)
-	token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
+	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
 	if err != nil {
 		return nil, err
 	}
 
-	return Post(url, token, params, ZHIPUAI_MODEL_Default_Timeout)
+	return Post(url, token, params, ZhipuaiModelDefaultTimeout)
 }
 
 // AsyncInvoke only returns a task id which can be used to get result of task later
 func (z *ZhiPuAI) AsyncInvoke(params ModelParams) (*Response, error) {
 	url := BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke)
-	token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
+	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
 	if err != nil {
 		return nil, err
 	}
 
-	return Post(url, token, params, ZHIPUAI_MODEL_Default_Timeout)
+	return Post(url, token, params, ZhipuaiModelDefaultTimeout)
 }
 
 // Get result of task async-invoke
@@ -108,19 +108,56 @@ func (z *ZhiPuAI) Get(params ModelParams) (*Response, error) {
 
 	// url with task id
 	url := fmt.Sprintf("%s/%s", BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke), params.TaskID)
-	token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
+	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
 	if err != nil {
 		return nil, err
 	}
 
-	return Get(url, token, ZHIPUAI_MODEL_Default_Timeout)
+	return Get(url, token, ZhipuaiModelDefaultTimeout)
 }
 
 func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error {
 	url := BuildAPIURL(params.Model, ZhiPuAISSEInvoke)
-	token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
+	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
 	if err != nil {
 		return err
 	}
-	return Stream(url, token, params, ZHIPUAI_MODEL_Default_Timeout, nil)
-}
+	return Stream(url, token, params, ZhipuaiModelDefaultTimeout, handler)
+}
+
+func (z *ZhiPuAI) Validate() (*Response, error) {
+	url := BuildAPIURL(ZhiPuAILite, ZhiPuAIAsyncInvoke)
+	token, err := GenerateToken(z.apiKey, APITokenTTLSeconds)
+	if err != nil {
+		return nil, err
+	}
+
+	testPrompt := []Prompt{
+		{
+			Role:    "user",
+			Content: "Hello!",
+		},
+	}
+
+	testParam := ModelParams{
+		Method:      ZhiPuAIAsyncInvoke,
+		Model:       ZhiPuAILite,
+		Temperature: 0.95,
+		TopP:        0.7,
+		Prompt:      testPrompt,
+	}
+
+	postResponse, err := Post(url, token, testParam, ZhipuaiModelDefaultTimeout)
+	if err != nil {
+		return nil, err
+	}
+
+	testParam.TaskID = postResponse.Data.TaskID
+
+	getResponse, err := z.Get(testParam)
+	if err != nil {
+		return nil, err
+	}
+
+	return getResponse, nil
+}
diff --git a/pkg/llms/zhipuai/jwt_token.go b/pkg/llms/zhipuai/jwt_token.go
index b311dc161..60f62d94d 100644
--- a/pkg/llms/zhipuai/jwt_token.go
+++ b/pkg/llms/zhipuai/jwt_token.go
@@ -26,9 +26,9 @@ import (
 )
 
 const (
-	API_TOKEN_TTL_SECONDS = 3 * 60
+	APITokenTTLSeconds = 3 * 60
 	// FIXME: impl TTL Cache
-	CACHE_TTL_SECONDS = (API_TOKEN_TTL_SECONDS - 30)
+	CacheTTLSeconds = (APITokenTTLSeconds - 30)
 )
 
 func GenerateToken(apikey string, expSeconds int64) (string, error) {
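
Reviewer note: a minimal sketch of how the extracted Validate entry points are
meant to be consumed outside the controller, mirroring the type switch in
LLMReconciler.CheckLLM above. It assumes the llms package exposes the LLMType
constants (llms.OpenAI, llms.ZhiPuAI) and an llms.Response interface with a
String() method, as CheckLLM relies on; the validateExample helper itself is
illustrative and not part of this patch.

    package main

    import (
        "fmt"
        "log"

        "github.com/kubeagi/arcadia/pkg/llms"
        "github.com/kubeagi/arcadia/pkg/llms/zhipuai"
    )

    // validateExample picks a provider-specific validator by LLM type,
    // runs it, and logs the provider's response on success.
    func validateExample(llmType llms.LLMType, apiKey string) error {
        var response llms.Response
        var err error

        switch llmType {
        case llms.ZhiPuAI:
            validator := zhipuai.NewZhiPuAI(apiKey)
            response, err = validator.Validate()
        case llms.OpenAI:
            // Kept disabled in CheckLLM until the OpenAI response object
            // is settled (see the FIXME in pkg/llms/openai/api.go).
            return fmt.Errorf("OpenAI validation not wired up yet")
        default:
            return fmt.Errorf("unknown LLM type: %s", llmType)
        }
        if err != nil {
            return err
        }

        log.Printf("LLM available: %s", response.String())
        return nil
    }

    func main() {
        if err := validateExample(llms.ZhiPuAI, "<your-api-key>"); err != nil {
            log.Fatal(err)
        }
    }

One nice property of this shape: adding a new provider only means adding a
case to the switch plus a package with its own Validate, instead of growing
the controller's HTTP plumbing as TestLLMAvailability did.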
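
Reviewer note on the "FIXME: impl TTL Cache" in pkg/llms/zhipuai/jwt_token.go:
a minimal sketch of the token cache it calls for, reusing a generated JWT
until CacheTTLSeconds (the token TTL minus a 30-second safety margin) has
elapsed. The tokenCache type is illustrative only and not part of this patch.

    package zhipuai

    import (
        "sync"
        "time"
    )

    // tokenCache memoizes GenerateToken results per API key so that a fresh
    // JWT is only minted once the cached one nears its 3-minute expiry.
    type tokenCache struct {
        mu      sync.Mutex
        tokens  map[string]string
        expires map[string]time.Time
    }

    func newTokenCache() *tokenCache {
        return &tokenCache{
            tokens:  make(map[string]string),
            expires: make(map[string]time.Time),
        }
    }

    // get returns a cached token while it is still fresh; otherwise it
    // generates a new one and caches it for CacheTTLSeconds.
    func (c *tokenCache) get(apiKey string) (string, error) {
        c.mu.Lock()
        defer c.mu.Unlock()

        if exp, ok := c.expires[apiKey]; ok && time.Now().Before(exp) {
            return c.tokens[apiKey], nil
        }

        token, err := GenerateToken(apiKey, APITokenTTLSeconds)
        if err != nil {
            return "", err
        }
        c.tokens[apiKey] = token
        c.expires[apiKey] = time.Now().Add(CacheTTLSeconds * time.Second)
        return token, nil
    }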