From 497f112658e573f03a3a56d4ef799b8a71ac4cbb Mon Sep 17 00:00:00 2001 From: bjwswang Date: Thu, 18 Jan 2024 11:42:38 +0800 Subject: [PATCH] chore: add new models supported by zhipuai Signed-off-by: bjwswang --- pkg/langchainwrap/llm.go | 3 +-- pkg/llms/llms.go | 6 +++++ pkg/llms/zhipuai/langchainllm.go | 42 +++++++++++++++++++++++++------- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/pkg/langchainwrap/llm.go b/pkg/langchainwrap/llm.go index b8a0ab304..64a7c63e7 100644 --- a/pkg/langchainwrap/llm.go +++ b/pkg/langchainwrap/llm.go @@ -48,8 +48,7 @@ func GetLangchainLLM(ctx context.Context, llm *v1alpha1.LLM, c client.Client, cl } switch llm.Spec.Type { case llms.ZhiPuAI: - z := zhipuai.NewZhiPuAI(apiKey) - return &zhipuai.ZhiPuAILLM{ZhiPuAI: *z, RetryTimes: 3}, nil + return zhipuai.NewZhiPuAILLM(apiKey, zhipuai.WithRetryTimes(3)), nil case llms.OpenAI: return openai.New(openai.WithToken(apiKey), openai.WithBaseURL(llm.Spec.Endpoint.URL)) } diff --git a/pkg/llms/llms.go b/pkg/llms/llms.go index 55afbfd4e..6717e6564 100644 --- a/pkg/llms/llms.go +++ b/pkg/llms/llms.go @@ -39,6 +39,12 @@ var ( ZhiPuAIStd string = "chatglm_std" ZhiPuAIPro string = "chatglm_pro" ZhiPuAITurbo string = "chatglm_turbo" + // ChatGLM3 + ZhiPuAIGLM3Turbo string = "glm-3-turbo" + // ChatGLM4 + ZhiPuAIGLM4 string = "glm-4" + // Character LLM + ZhiPuAICharGLM3 string = "charglm-3" ) var ZhiPuAIModels = []string{ZhiPuAILite, ZhiPuAIStd, ZhiPuAIPro, ZhiPuAITurbo} diff --git a/pkg/llms/zhipuai/langchainllm.go b/pkg/llms/zhipuai/langchainllm.go index 5b59ffb22..4a98886a5 100644 --- a/pkg/llms/zhipuai/langchainllm.go +++ b/pkg/llms/zhipuai/langchainllm.go @@ -42,9 +42,35 @@ var ( _ langchainllm.LLM = (*ZhiPuAILLM)(nil) ) +type options struct { + retryTimes int +} + +type Option func(*options) + +func WithRetryTimes(retryTimes int) Option { + return func(o *options) { + o.retryTimes = retryTimes + } +} + type ZhiPuAILLM struct { - ZhiPuAI - RetryTimes int + c *ZhiPuAI + options *options +} + +func NewZhiPuAILLM(apiKey string, opts ...Option) *ZhiPuAILLM { + z := &ZhiPuAILLM{ + c: NewZhiPuAI(apiKey), + options: &options{ + // 2 times by default + retryTimes: 2, + }, + } + for _, opt := range opts { + opt(z.options) + } + return z } func (z *ZhiPuAILLM) GetNumTokens(text string) int { @@ -83,12 +109,10 @@ func (z *ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ... } for _, prompt := range prompts { params.Prompt = append(params.Prompt, Prompt{Role: User, Content: prompt}) - klog.Infof("get prompts: %#v\n", params.Prompt) - client := NewZhiPuAI(z.apiKey) needStream := opts.StreamingFunc != nil if needStream { res := bytes.NewBuffer(nil) - err := client.SSEInvoke(params, func(event *sse.Event) { + err := z.c.SSEInvoke(params, func(event *sse.Event) { if string(event.Event) == "finish" { return } @@ -109,7 +133,7 @@ func (z *ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ... i := 0 for { i++ - resp, err = client.Invoke(params) + resp, err = z.c.Invoke(params) if err != nil { return nil, err } @@ -117,10 +141,10 @@ func (z *ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ... return nil, ErrEmptyResponse } if resp.Data == nil { - klog.Errorf("zhipullm get empty response: msg:%s code:%d\n", resp.Msg, resp.Code) - if i <= z.RetryTimes && (resp.Code == CodeConcurrencyHigh || resp.Code == CodefrequencyHigh || resp.Code == CodeTimesHigh) { + klog.Errorf("empty response: msg:%s code:%d\n", resp.Msg, resp.Code) + if i <= z.options.retryTimes { r := rand.Intn(5) - klog.Infof("zhipullm triggers retry[%d], sleep %d seconds, then recall...\n", i, r) + klog.Infof("retry[%d], sleep %d seconds, then recall...\n", i, r) time.Sleep(time.Duration(r) * time.Second) continue }