Skip to content

Commit

Permalink
Merge pull request #574 from bjwswang/zhipuai
Browse files Browse the repository at this point in the history
chore: add new models supported by zhipuai
  • Loading branch information
bjwswang authored Jan 18, 2024
2 parents ac8f685 + 497f112 commit 517d75b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
3 changes: 1 addition & 2 deletions pkg/langchainwrap/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ func GetLangchainLLM(ctx context.Context, llm *v1alpha1.LLM, c client.Client, cl
}
switch llm.Spec.Type {
case llms.ZhiPuAI:
z := zhipuai.NewZhiPuAI(apiKey)
return &zhipuai.ZhiPuAILLM{ZhiPuAI: *z, RetryTimes: 3}, nil
return zhipuai.NewZhiPuAILLM(apiKey, zhipuai.WithRetryTimes(3)), nil
case llms.OpenAI:
return openai.New(openai.WithToken(apiKey), openai.WithBaseURL(llm.Spec.Endpoint.URL))
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ var (
ZhiPuAIStd string = "chatglm_std"
ZhiPuAIPro string = "chatglm_pro"
ZhiPuAITurbo string = "chatglm_turbo"
// ChatGLM3
ZhiPuAIGLM3Turbo string = "glm-3-turbo"
// ChatGLM4
ZhiPuAIGLM4 string = "glm-4"
// Character LLM
ZhiPuAICharGLM3 string = "charglm-3"
)
// ZhiPuAIModels lists ZhiPuAI model identifiers.
// NOTE(review): ZhiPuAIGLM3Turbo, ZhiPuAIGLM4 and ZhiPuAICharGLM3 are declared
// in the var block just above but are not part of this list — confirm whether
// that omission is intentional.
var ZhiPuAIModels = []string{ZhiPuAILite, ZhiPuAIStd, ZhiPuAIPro, ZhiPuAITurbo}

Expand Down
42 changes: 33 additions & 9 deletions pkg/llms/zhipuai/langchainllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,35 @@ var (
_ langchainllm.LLM = (*ZhiPuAILLM)(nil)
)

// options collects the tunable settings carried by a ZhiPuAILLM.
type options struct {
	// retryTimes is the retry count; it is set through WithRetryTimes.
	retryTimes int
}

// Option is a function that adjusts an options value in place.
type Option func(*options)

// WithRetryTimes returns an Option that stores retryTimes as the retry count.
func WithRetryTimes(retryTimes int) Option {
	apply := func(cfg *options) {
		cfg.retryTimes = retryTimes
	}
	return apply
}

// ZhiPuAILLM is the type asserted at package level to implement langchainllm.LLM.
// NOTE(review): this text is a rendered diff whose +/- markers were stripped, so
// the field list below appears to mix the removed fields (ZhiPuAI, RetryTimes)
// with the added ones (c, options) — confirm against the real file.
type ZhiPuAILLM struct {
// NOTE(review): presumably the pre-change embedded client — confirm.
ZhiPuAI
// NOTE(review): presumably the pre-change retry count, superseded by
// options.retryTimes — confirm.
RetryTimes int
// c is the client built by NewZhiPuAI in NewZhiPuAILLM; Generate calls its
// Invoke and SSEInvoke methods.
c *ZhiPuAI
// options holds the settings adjusted through Option values (retryTimes).
options *options
}

// NewZhiPuAILLM builds a ZhiPuAILLM whose client is created from apiKey via
// NewZhiPuAI.
//
// The retry count defaults to 2 and can be overridden with WithRetryTimes.
// Options are applied in the order given, so a later option overrides an
// earlier one that touches the same setting. Nil options are skipped rather
// than invoked, because calling a nil function value panics at run time.
func NewZhiPuAILLM(apiKey string, opts ...Option) *ZhiPuAILLM {
	z := &ZhiPuAILLM{
		c: NewZhiPuAI(apiKey),
		options: &options{
			// 2 times by default
			retryTimes: 2,
		},
	}
	for _, opt := range opts {
		// Guard against a nil Option so a careless caller cannot crash construction.
		if opt == nil {
			continue
		}
		opt(z.options)
	}
	return z
}

func (z *ZhiPuAILLM) GetNumTokens(text string) int {
Expand Down Expand Up @@ -83,12 +109,10 @@ func (z *ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ...
}
for _, prompt := range prompts {
params.Prompt = append(params.Prompt, Prompt{Role: User, Content: prompt})
klog.Infof("get prompts: %#v\n", params.Prompt)
client := NewZhiPuAI(z.apiKey)
needStream := opts.StreamingFunc != nil
if needStream {
res := bytes.NewBuffer(nil)
err := client.SSEInvoke(params, func(event *sse.Event) {
err := z.c.SSEInvoke(params, func(event *sse.Event) {
if string(event.Event) == "finish" {
return
}
Expand All @@ -109,18 +133,18 @@ func (z *ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ...
i := 0
for {
i++
resp, err = client.Invoke(params)
resp, err = z.c.Invoke(params)
if err != nil {
return nil, err
}
if resp == nil {
return nil, ErrEmptyResponse
}
if resp.Data == nil {
klog.Errorf("zhipullm get empty response: msg:%s code:%d\n", resp.Msg, resp.Code)
if i <= z.RetryTimes && (resp.Code == CodeConcurrencyHigh || resp.Code == CodefrequencyHigh || resp.Code == CodeTimesHigh) {
klog.Errorf("empty response: msg:%s code:%d\n", resp.Msg, resp.Code)
if i <= z.options.retryTimes {
r := rand.Intn(5)
klog.Infof("zhipullm triggers retry[%d], sleep %d seconds, then recall...\n", i, r)
klog.Infof("retry[%d], sleep %d seconds, then recall...\n", i, r)
time.Sleep(time.Duration(r) * time.Second)
continue
}
Expand Down

0 comments on commit 517d75b

Please sign in to comment.