Skip to content

Commit

Permalink
Merge pull request kubeagi#883 from bjwswang/main
Browse files Browse the repository at this point in the history
fix: set/deploy a default embedder for kubeagi
  • Loading branch information
bjwswang authored Mar 18, 2024
2 parents a9d9344 + 4afa402 commit b6ebd39
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 27 deletions.
2 changes: 1 addition & 1 deletion config/samples/arcadia_v1alpha1_worker_baichuan2-7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ spec:
image: kubeagi/minio-mc:RELEASE.2023-01-28T20-29-38Z
imagePullPolicy: IfNotPresent
runner:
image: kubeagi/arcadia-fastchat-worker:v0.2.0
image: kubeagi/arcadia-fastchat-worker:v0.2.36
imagePullPolicy: IfNotPresent
resources:
limits:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ spec:
image: kubeagi/minio-mc:RELEASE.2023-01-28T20-29-38Z
imagePullPolicy: IfNotPresent
runner:
image: kubeagi/arcadia-fastchat-worker:v0.2.0
image: kubeagi/arcadia-fastchat-worker:v0.2.36
imagePullPolicy: IfNotPresent
model:
kind: "Models"
Expand Down
2 changes: 1 addition & 1 deletion config/samples/arcadia_v1alpha1_worker_qwen-7b-chat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ spec:
image: kubeagi/minio-mc:RELEASE.2023-01-28T20-29-38Z
imagePullPolicy: IfNotPresent
runner:
image: kubeagi/arcadia-fastchat-worker:v0.2.0
image: kubeagi/arcadia-fastchat-worker:v0.2.36
imagePullPolicy: IfNotPresent
resources:
limits:
Expand Down
2 changes: 1 addition & 1 deletion deploy/charts/arcadia/Chart.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ apiVersion: v2
name: arcadia
description: A Helm chart(Also a KubeBB Component) for KubeAGI Arcadia
type: application
version: 0.3.19
version: 0.3.20
appVersion: "0.2.0"

keywords:
Expand Down
1 change: 0 additions & 1 deletion deploy/charts/arcadia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ global settings of arcadia chart.

| Parameter | Description | Default |
| ------------------------ | ------------------------------------------------------------ | ----------- |
| `oss.bucket` | Name of the bucket where data is stored | `"arcadia"` |
| `defaultVectorStoreType` | Defines the default vector database type; currently `chroma` and `pgvector` are available | `pgvector` |

### controller
Expand Down
12 changes: 9 additions & 3 deletions deploy/charts/arcadia/templates/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,20 @@ data:
name: '{{ .Release.Name }}-pgvector-vectorstore'
{{- end }}
namespace: '{{ .Release.Namespace }}'
{{- if .Values.rerank.enabled }}
{{- if .Values.config.embedder.enabled }}
embedder:
apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: Embedder
name: {{ .Release.Name }}-embedder
namespace: {{ .Release.Namespace }}
{{- end }}
{{- if .Values.config.rerank.enabled }}
rerank:
apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: Model
name: bge-reranker-large
name: {{ .Values.config.rerank.model }}
namespace: {{ .Release.Namespace }}
{{- end }}

#streamlit:
# image: 172.22.96.34/cluster_system/streamlit:v1.29.0
# ingressClassName: portal-ingress
Expand Down
19 changes: 19 additions & 0 deletions deploy/charts/arcadia/templates/post-embedder.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{{- if .Values.config.embedder.enabled }}
apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
kind: Worker
metadata:
name: {{ .Release.Name }}-embedder
namespace: {{ .Release.Namespace }}
annotations:
"helm.sh/hook": post-install
"helm.sh/hook-weight": "2"
spec:
displayName: SystemEmbedder
description: "这是系统默认使用的Embedding模型服务"
type: "fastchat"
replicas: 1
model:
kind: "Models"
name: {{ .Values.config.embedder.model }}
namespace: {{ .Release.Namespace }}
{{- end }}
12 changes: 12 additions & 0 deletions deploy/charts/arcadia/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ global:
- postgresql.<replaced-ingress-nginx-ip>.nip.io
ip: <replaced-ingress-nginx-ip>

# @section config is used to configure the system
config:
# embedder is used as the system default embedding service
embedder:
enabled: true
model: "bge-large-zh-v1.5"
# rerank is the default model for reranking service
rerank:
enabled: true
model: "bge-reranker-large"


# @section controller is used as the core controller for arcadia
# @param image Image to be used
# @param imagePullPolicy ImagePullPolicy
Expand Down
19 changes: 9 additions & 10 deletions pkg/appruntime/chain/llmchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any
instance := l.Instance
options := GetChainOptions(instance.Spec.CommonChainConfig)

needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
}

// Check whether any files were provided as input
v3, ok := args["documents"]
if ok {
Expand Down Expand Up @@ -123,12 +117,17 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any
l.LLMChain = *chain

var out string

// Predict based on options
if len(options) > 0 {
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
out, err = chains.Predict(ctx, l.LLMChain, args)
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
out, err = chains.Predict(ctx, l.LLMChain, args)
}
}

out, err = handleNoErrNoOut(ctx, needStream, out, err, l.LLMChain, args, options)
Expand Down
18 changes: 9 additions & 9 deletions pkg/appruntime/chain/retrievalqachain.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st

instance := l.Instance
options := GetChainOptions(instance.Spec.CommonChainConfig)
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
}

// Check whether any files were provided as input
v5, ok := args["documents"]
Expand Down Expand Up @@ -137,12 +132,17 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st
l.ConversationalRetrievalQA = chain
args["query"] = args["question"]
var out string

// Predict based on options
if len(options) > 0 {
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
} else {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args)
if len(options) > 0 {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
} else {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args)
}
}

out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options)
Expand Down

0 comments on commit b6ebd39

Please sign in to comment.