Skip to content

Commit

Permalink
Merge pull request #53 from bjwswang/llms
Browse files Browse the repository at this point in the history
feat: enable sse-invoke in zhipuai and standarlize response body
  • Loading branch information
bjwswang authored Aug 23, 2023
2 parents 75639d3 + 746251b commit 8dc2aad
Show file tree
Hide file tree
Showing 12 changed files with 461 additions and 81 deletions.
82 changes: 82 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
name: Integration tests
on:
push:
branches:
- 'main'
- 'release-*'
- 'v*'
- 'master'
pull_request:
branches:
- 'main'
- 'release-*'
- 'v*'
- 'master'

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read

jobs:
check-go-mod:
name: Ensure Go modules synchronicity
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Setup Golang
uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
- name: Download all Go modules
run: |
go mod download
- name: Check for tidyness of go.mod and go.sum
run: |
go mod tidy
git diff --exit-code -- .
build-go:
name: Build & cache Go code
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Setup Golang
uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
- name: Restore go build cache
uses: actions/cache@v3
with:
path: ~/.cache/go-build
key: ${{ runner.os }}-go-build-${{ github.run_id }}
- name: Download all Go modules
run: |
go mod download
- name: Compile all packages
run: make build

lint-go:
permissions:
contents: read # for actions/checkout to fetch code
pull-requests: read # for golangci/golangci-lint-action to fetch pull requests
name: Lint Go code
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Setup Golang
uses: actions/setup-go@v4
with:
go-version-file: 'go.mod'
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: latest
# show only new issues if it's a pull request.
only-new-issues: true
skip-cache: true
5 changes: 2 additions & 3 deletions controllers/prompt_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package controllers

import (
"context"
"encoding/json"
"fmt"

"github.com/go-logr/logr"
Expand Down Expand Up @@ -96,7 +95,7 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
}

// llm call
var resp = make(map[string]interface{})
var resp llms.Response
var err error
switch llm.Spec.Type {
case llms.ZhiPuAI:
Expand All @@ -123,7 +122,7 @@ func (r *PromptReconciler) CallLLM(ctx context.Context, logger logr.Logger, prom
}
promptDeepCodpy.Status.ConditionedStatus = arcadiav1alpha1.ConditionedStatus{Conditions: []arcadiav1alpha1.Condition{newCond}}
if resp != nil {
promptDeepCodpy.Status.Data, _ = json.Marshal(resp)
promptDeepCodpy.Status.Data = resp.Bytes()
}

return r.Status().Update(ctx, promptDeepCodpy)
Expand Down
77 changes: 68 additions & 9 deletions examples/zhipuai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,89 @@ limitations under the License.
package main

import (
"fmt"
"os"

"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"k8s.io/klog/v2"
)

func main() {
if len(os.Args) == 0 {
panic("api key is empty")
}
apiKey := os.Args[1]

klog.V(0).Info("try `Invoke`")
resp, err := sampleInvoke(apiKey)
if err != nil {
panic(err)
}
fmt.Printf("SampleInvoke: \n %+v\n", resp)
klog.V(0).Info("Response: \n %s\n", resp.String())

klog.V(0).Info("try `AsyncInvoke`")
resp, err = sampleInvokeAsync(apiKey)
if err != nil {
panic(err)
}
// fmt.Printf("sampleInvokeAsync: \n %+v\n", resp)
// taskID := "76997570932704279317856632766629711813"
// resp, err = getInvokeAsyncResult(apiKey, taskID)
// if err != nil {
// panic(err)
// }
// fmt.Printf("getInvokeAsyncResult: \n %+v\n", resp)
klog.V(0).Info("Response: \n %s\n", resp.String())

var taskID string
if resp.Data != nil {
taskID = resp.Data.TaskID
}
if taskID == "" {
panic("Failed to get task id from previous AsyncInvoke response")
}

klog.V(0).Info("try `getInvokeAsyncResult` with previous task id")
resp, err = getInvokeAsyncResult(apiKey, taskID)
if err != nil {
panic(err)
}
klog.V(0).Info("Response: \n %s\n", resp.String())

klog.V(0).Info("try `SSEInvoke` with default handler")
err = sampleSSEInvoke(apiKey)
if err != nil {
panic(err)
}
}

func sampleInvoke(apiKey string) (*zhipuai.Response, error) {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
return client.Invoke(params)
}

func sampleInvokeAsync(apiKey string) (*zhipuai.Response, error) {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
return client.AsyncInvoke(params)
}

func getInvokeAsyncResult(apiKey string, taskID string) (*zhipuai.Response, error) {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.TaskID = taskID
return client.Get(params)
}

func sampleSSEInvoke(apiKey string) error {
client := zhipuai.NewZhiPuAI(apiKey)
params := zhipuai.DefaultModelParams()
params.Prompt = []zhipuai.Prompt{
{Role: zhipuai.User, Content: "As a kubernetes expert,please answer the following questions."},
}
// you can define a customized `handler` on `Event`
err := client.SSEInvoke(params, nil)
if err != nil {
return err
}
return nil
}
46 changes: 0 additions & 46 deletions examples/zhipuai/samples.go

This file was deleted.

6 changes: 4 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ module github.com/kubeagi/arcadia
go 1.20

require (
github.com/go-logr/logr v1.2.0
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.18.1
github.com/r3labs/sse/v2 v2.10.0
k8s.io/api v0.24.2
k8s.io/apimachinery v0.24.2
k8s.io/client-go v0.24.2
k8s.io/klog/v2 v2.60.1
sigs.k8s.io/controller-runtime v0.12.2
)

Expand All @@ -29,7 +32,6 @@ require (
github.com/evanphx/json-patch v4.12.0+incompatible // indirect
github.com/form3tech-oss/jwt-go v3.2.3+incompatible // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/go-logr/logr v1.2.0 // indirect
github.com/go-logr/zapr v1.2.0 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
github.com/go-openapi/jsonreference v0.19.5 // indirect
Expand Down Expand Up @@ -69,13 +71,13 @@ require (
gomodules.xyz/jsonpatch/v2 v2.2.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
k8s.io/apiextensions-apiserver v0.24.2 // indirect
k8s.io/component-base v0.24.2 // indirect
k8s.io/klog/v2 v2.60.1 // indirect
k8s.io/kube-openapi v0.0.0-20220328201542-3ee0da9b0b42 // indirect
k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9 // indirect
sigs.k8s.io/json v0.0.0-20211208200746-9f7c6b3444d2 // indirect
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3 h1:4jVXhlkAyzOScmCkXBTOLRLTz8EeU+eyjrwB/EPq0VU=
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0=
github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
Expand Down Expand Up @@ -580,6 +582,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down Expand Up @@ -900,6 +903,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y=
gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
6 changes: 6 additions & 0 deletions pkg/llms/llms.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ const (
OpenAI LLMType = "openai"
ZhiPuAI LLMType = "zhipuai"
)

type Response interface {
Type() LLMType
String() string
Bytes() []byte
}
19 changes: 15 additions & 4 deletions pkg/llms/zhipuai/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"errors"
"fmt"
"time"

"github.com/r3labs/sse/v2"
)

const (
Expand Down Expand Up @@ -63,7 +65,7 @@ func NewZhiPuAI(apiKey string) *ZhiPuAI {
}

// Call wraps a common AI api call
func (z *ZhiPuAI) Call(params ModelParams) (map[string]interface{}, error) {
func (z *ZhiPuAI) Call(params ModelParams) (*Response, error) {
switch params.Method {
case ZhiPuAIInvoke:
return z.Invoke(params)
Expand All @@ -77,7 +79,7 @@ func (z *ZhiPuAI) Call(params ModelParams) (map[string]interface{}, error) {
}

// Invoke calls zhipuai and returns result immediately
func (z *ZhiPuAI) Invoke(params ModelParams) (map[string]interface{}, error) {
func (z *ZhiPuAI) Invoke(params ModelParams) (*Response, error) {
url := BuildAPIURL(params.Model, ZhiPuAIInvoke)
token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
if err != nil {
Expand All @@ -88,7 +90,7 @@ func (z *ZhiPuAI) Invoke(params ModelParams) (map[string]interface{}, error) {
}

// AsyncInvoke only returns a task id which can be used to get result of task later
func (z *ZhiPuAI) AsyncInvoke(params ModelParams) (map[string]interface{}, error) {
func (z *ZhiPuAI) AsyncInvoke(params ModelParams) (*Response, error) {
url := BuildAPIURL(params.Model, ZhiPuAIAsyncInvoke)
token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
if err != nil {
Expand All @@ -99,7 +101,7 @@ func (z *ZhiPuAI) AsyncInvoke(params ModelParams) (map[string]interface{}, error
}

// Get result of task async-invoke
func (z *ZhiPuAI) Get(params ModelParams) (map[string]interface{}, error) {
func (z *ZhiPuAI) Get(params ModelParams) (*Response, error) {
if params.TaskID == "" {
return nil, errors.New("TaskID is required when running Get with method AsyncInvoke")
}
Expand All @@ -113,3 +115,12 @@ func (z *ZhiPuAI) Get(params ModelParams) (map[string]interface{}, error) {

return Get(url, token, ZHIPUAI_MODEL_Default_Timeout)
}

func (z *ZhiPuAI) SSEInvoke(params ModelParams, handler func(*sse.Event)) error {
url := BuildAPIURL(params.Model, ZhiPuAISSEInvoke)
token, err := GenerateToken(z.apiKey, API_TOKEN_TTL_SECONDS)
if err != nil {
return err
}
return Stream(url, token, params, ZHIPUAI_MODEL_Default_Timeout, nil)
}
Loading

0 comments on commit 8dc2aad

Please sign in to comment.