diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 84c965d45..49acc1a91 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -97,7 +97,7 @@ const docTemplate = `{ "schema": { "type": "array", "items": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } } }, @@ -188,7 +188,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } }, "400": { @@ -1012,37 +1012,6 @@ const docTemplate = `{ } } }, - "chat.Conversation": { - "type": "object", - "properties": { - "app_name": { - "type": "string", - "example": "chat-with-llm" - }, - "app_namespace": { - "type": "string", - "example": "arcadia" - }, - "id": { - "type": "string", - "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.Message" - } - }, - "started_at": { - "type": "string", - "example": "2023-12-21T10:21:06.389359092+08:00" - }, - "updated_at": { - "type": "string", - "example": "2023-12-22T10:21:06.389359092+08:00" - } - } - }, "chat.ConversationReqBody": { "type": "object", "required": [ @@ -1076,29 +1045,6 @@ const docTemplate = `{ } } }, - "chat.Message": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "example": "旷工最小计算单位为0.5天。" - }, - "id": { - "type": "string", - "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" - }, - "query": { - "type": "string", - "example": "旷工最小计算单位为多少天?" - }, - "references": { - "type": "array", - "items": { - "$ref": "#/definitions/retriever.Reference" - } - } - } - }, "chat.MessageReqBody": { "type": "object", "required": [ @@ -1341,6 +1287,60 @@ const docTemplate = `{ "type": "string" } } + }, + "storage.Conversation": { + "type": "object", + "properties": { + "app_name": { + "type": "string", + "example": "chat-with-llm" + }, + "app_namespace": { + "type": "string", + "example": "arcadia" + }, + "id": { + "type": "string", + "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/storage.Message" + } + }, + "started_at": { + "type": "string", + "example": "2023-12-21T10:21:06.389359092+08:00" + }, + "updated_at": { + "type": "string", + "example": "2023-12-22T10:21:06.389359092+08:00" + } + } + }, + "storage.Message": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "example": "旷工最小计算单位为0.5天。" + }, + "id": { + "type": "string", + "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" + }, + "query": { + "type": "string", + "example": "旷工最小计算单位为多少天?" + }, + "references": { + "type": "array", + "items": { + "$ref": "#/definitions/retriever.Reference" + } + } + } } } }` diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index 83ee95fc0..f411f05fd 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -86,7 +86,7 @@ "schema": { "type": "array", "items": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } } }, @@ -177,7 +177,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } }, "400": { @@ -1001,37 +1001,6 @@ } } }, - "chat.Conversation": { - "type": "object", - "properties": { - "app_name": { - "type": "string", - "example": "chat-with-llm" - }, - "app_namespace": { - "type": "string", - "example": "arcadia" - }, - "id": { - "type": "string", - "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.Message" - } - }, - "started_at": { - "type": "string", - "example": "2023-12-21T10:21:06.389359092+08:00" - }, - "updated_at": { - "type": "string", - "example": "2023-12-22T10:21:06.389359092+08:00" - } - } - }, "chat.ConversationReqBody": { "type": "object", "required": [ @@ -1065,29 +1034,6 @@ } } }, - "chat.Message": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "example": "旷工最小计算单位为0.5天。" - }, - "id": { - "type": "string", - "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" - }, - "query": { - "type": "string", - "example": "旷工最小计算单位为多少天?" - }, - "references": { - "type": "array", - "items": { - "$ref": "#/definitions/retriever.Reference" - } - } - } - }, "chat.MessageReqBody": { "type": "object", "required": [ @@ -1330,6 +1276,60 @@ "type": "string" } } + }, + "storage.Conversation": { + "type": "object", + "properties": { + "app_name": { + "type": "string", + "example": "chat-with-llm" + }, + "app_namespace": { + "type": "string", + "example": "arcadia" + }, + "id": { + "type": "string", + "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/storage.Message" + } + }, + "started_at": { + "type": "string", + "example": "2023-12-21T10:21:06.389359092+08:00" + }, + "updated_at": { + "type": "string", + "example": "2023-12-22T10:21:06.389359092+08:00" + } + } + }, + "storage.Message": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "example": "旷工最小计算单位为0.5天。" + }, + "id": { + "type": "string", + "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" + }, + "query": { + "type": "string", + "example": "旷工最小计算单位为多少天?" + }, + "references": { + "type": "array", + "items": { + "$ref": "#/definitions/retriever.Reference" + } + } + } } } } \ No newline at end of file diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index de9d73469..2cd9d1f06 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -67,28 +67,6 @@ definitions: $ref: '#/definitions/retriever.Reference' type: array type: object - chat.Conversation: - properties: - app_name: - example: chat-with-llm - type: string - app_namespace: - example: arcadia - type: string - id: - example: 5a41f3ca-763b-41ec-91c3-4bbbb00736d0 - type: string - messages: - items: - $ref: '#/definitions/chat.Message' - type: array - started_at: - example: "2023-12-21T10:21:06.389359092+08:00" - type: string - updated_at: - example: "2023-12-22T10:21:06.389359092+08:00" - type: string - type: object chat.ConversationReqBody: properties: app_name: @@ -113,22 +91,6 @@ definitions: example: conversation is not found type: string type: object - chat.Message: - properties: - answer: - example: 旷工最小计算单位为0.5天。 - type: string - id: - example: 4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24 - type: string - query: - example: 旷工最小计算单位为多少天? - type: string - references: - items: - $ref: '#/definitions/retriever.Reference' - type: array - type: object chat.MessageReqBody: properties: app_name: @@ -298,6 +260,44 @@ definitions: uploadID: type: string type: object + storage.Conversation: + properties: + app_name: + example: chat-with-llm + type: string + app_namespace: + example: arcadia + type: string + id: + example: 5a41f3ca-763b-41ec-91c3-4bbbb00736d0 + type: string + messages: + items: + $ref: '#/definitions/storage.Message' + type: array + started_at: + example: "2023-12-21T10:21:06.389359092+08:00" + type: string + updated_at: + example: "2023-12-22T10:21:06.389359092+08:00" + type: string + type: object + storage.Message: + properties: + answer: + example: 旷工最小计算单位为0.5天。 + type: string + id: + example: 4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24 + type: string + query: + example: 旷工最小计算单位为多少天? + type: string + references: + items: + $ref: '#/definitions/retriever.Reference' + type: array + type: object info: contact: {} paths: @@ -355,7 +355,7 @@ paths: description: OK schema: items: - $ref: '#/definitions/chat.Conversation' + $ref: '#/definitions/storage.Conversation' type: array "400": description: Bad Request @@ -415,7 +415,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/chat.Conversation' + $ref: '#/definitions/storage.Conversation' "400": description: Bad Request schema: diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat.go deleted file mode 100644 index 1afcbc8d2..000000000 --- a/apiserver/pkg/chat/chat.go +++ /dev/null @@ -1,180 +0,0 @@ -/* -Copyright 2023 KubeAGI. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package chat - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/tmc/langchaingo/memory" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/klog/v2" - - "github.com/kubeagi/arcadia/api/base/v1alpha1" - "github.com/kubeagi/arcadia/apiserver/pkg/auth" - "github.com/kubeagi/arcadia/apiserver/pkg/client" - "github.com/kubeagi/arcadia/pkg/appruntime" - "github.com/kubeagi/arcadia/pkg/appruntime/base" - "github.com/kubeagi/arcadia/pkg/appruntime/retriever" -) - -var ( - mu sync.Mutex - Conversations = map[string]Conversation{} -) - -func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messageID string) (*ChatRespBody, error) { - token := auth.ForOIDCToken(ctx) - c, err := client.GetClient(token) - if err != nil { - return nil, fmt.Errorf("failed to get a dynamic client: %w", err) - } - obj, err := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "applications"}). - Namespace(req.AppNamespace).Get(ctx, req.APPName, metav1.GetOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to get application: %w", err) - } - app := &v1alpha1.Application{} - err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), app) - if err != nil { - return nil, fmt.Errorf("failed to convert application: %w", err) - } - if !app.Status.IsReady() { - return nil, errors.New("application is not ready") - } - var conversation Conversation - currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - if !req.NewChat { - var ok bool - conversation, ok = Conversations[req.ConversationID] - if !ok { - return nil, errors.New("conversation is not found") - } - if currentUser != "" && currentUser != conversation.User { - return nil, errors.New("conversation id not match with user") - } - if conversation.AppName != req.APPName || conversation.AppNamespce != req.AppNamespace { - return nil, errors.New("conversation id not match with app info") - } - if conversation.Debug != req.Debug { - return nil, errors.New("conversation id not match with debug") - } - } else { - conversation = Conversation{ - ID: req.ConversationID, - AppName: req.APPName, - AppNamespce: req.AppNamespace, - StartedAt: time.Now(), - UpdatedAt: time.Now(), - Messages: make([]Message, 0), - History: memory.NewChatMessageHistory(), - User: currentUser, - Debug: req.Debug, - } - } - - conversation.Messages = append(conversation.Messages, Message{ - ID: messageID, - Query: req.Query, - Answer: "", - }) - ctx = base.SetAppNamespace(ctx, req.AppNamespace) - appRun, err := appruntime.NewAppOrGetFromCache(ctx, c, app) - if err != nil { - return nil, err - } - klog.FromContext(ctx).Info("begin to run application", "appName", req.APPName, "appNamespace", req.AppNamespace) - out, err := appRun.Run(ctx, c, respStream, appruntime.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: conversation.History}) - if err != nil { - return nil, err - } - - conversation.UpdatedAt = time.Now() - conversation.Messages[len(conversation.Messages)-1].Answer = out.Answer - conversation.Messages[len(conversation.Messages)-1].References = out.References - mu.Lock() - Conversations[conversation.ID] = conversation - mu.Unlock() - return &ChatRespBody{ - ConversationID: conversation.ID, - MessageID: messageID, - Message: out.Answer, - CreatedAt: time.Now(), - References: out.References, - }, nil -} - -func ListConversations(ctx context.Context, req APPMetadata) ([]Conversation, error) { - conversations := make([]Conversation, 0) - currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - for _, c := range Conversations { - if !c.Debug && c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && (currentUser == "" || currentUser == c.User) { - conversations = append(conversations, c) - } - } - mu.Unlock() - return conversations, nil -} - -func DeleteConversation(ctx context.Context, conversationID string) error { - currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - c, ok := Conversations[conversationID] - if ok && (currentUser == "" || currentUser == c.User) { - delete(Conversations, c.ID) - return nil - } else { - return errors.New("conversation is not found") - } -} - -func ListMessages(ctx context.Context, req ConversationReqBody) (Conversation, error) { - currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - for _, c := range Conversations { - if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && req.ConversationID == c.ID && (currentUser == "" || currentUser == c.User) { - return c, nil - } - } - return Conversation{}, errors.New("conversation is not found") -} - -func GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) { - currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - for _, c := range Conversations { - if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && c.ID == req.ConversationID && (currentUser == "" || currentUser == c.User) { - for _, m := range c.Messages { - if m.ID == req.MessageID { - return m.References, nil - } - } - } - } - return nil, errors.New("conversation or message is not found") -} - -// todo Reuse the flow without having to rebuild req same, not finish, Flow doesn't start with/contain nodes that depend on incomingInput.question diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go new file mode 100644 index 000000000..11a4c8025 --- /dev/null +++ b/apiserver/pkg/chat/chat_server.go @@ -0,0 +1,210 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package chat + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/tmc/langchaingo/memory" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/klog/v2" + + "github.com/kubeagi/arcadia/api/base/v1alpha1" + "github.com/kubeagi/arcadia/apiserver/pkg/auth" + "github.com/kubeagi/arcadia/apiserver/pkg/chat/storage" + "github.com/kubeagi/arcadia/apiserver/pkg/client" + "github.com/kubeagi/arcadia/pkg/appruntime" + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" + pkgconfig "github.com/kubeagi/arcadia/pkg/config" + "github.com/kubeagi/arcadia/pkg/datasource" +) + +var once sync.Once + +type ChatServer struct { + cli dynamic.Interface + storage storage.Storage +} + +func NewChatServer(cli dynamic.Interface) *ChatServer { + return &ChatServer{ + cli: cli, + } +} +func (cs *ChatServer) Storage() storage.Storage { + if cs.storage == nil { + once.Do(func() { + ctx := context.TODO() + ds, err := pkgconfig.GetRelationalDatasource(ctx, nil, cs.cli) + if err != nil || ds == nil { + if err != nil { + klog.Infof("get relational datasource failed: %s, use memory storage for chat", err.Error()) + } else if ds == nil { + klog.Infoln("no relational datasource found, use memory storage for chat") + } + cs.storage = storage.NewMemoryStorage() + } + pg, err := datasource.GetPostgreSQLPool(ctx, nil, cs.cli, ds) + if err != nil { + klog.Errorf("get postgresql pool failed : %s", err.Error()) + cs.storage = storage.NewMemoryStorage() + return + } + conn, err := pg.Pool.Acquire(ctx) + if err != nil { + klog.Errorf("postgresql pool acquire failed : %s", err.Error()) + cs.storage = storage.NewMemoryStorage() + return + } + db, err := storage.NewPostgreSQLStorage(conn.Conn()) + if err != nil { + klog.Errorf("storage.NewPostgreSQLStorage failed : %s", err.Error()) + cs.storage = storage.NewMemoryStorage() + return + } + klog.Infoln("use pg as chat storage.") + cs.storage = db + }) + } + return cs.storage +} + +func (cs *ChatServer) AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messageID string) (*ChatRespBody, error) { + token := auth.ForOIDCToken(ctx) + c, err := client.GetClient(token) + if err != nil { + return nil, fmt.Errorf("failed to get a dynamic client: %w", err) + } + obj, err := c.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "applications"}). + Namespace(req.AppNamespace).Get(ctx, req.APPName, metav1.GetOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to get application: %w", err) + } + app := &v1alpha1.Application{} + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), app) + if err != nil { + return nil, fmt.Errorf("failed to convert application: %w", err) + } + if !app.Status.IsReady() { + return nil, errors.New("application is not ready") + } + var conversation *storage.Conversation + history := memory.NewChatMessageHistory() + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + if !req.NewChat { + search := []storage.SearchOption{ + storage.WithAppName(req.APPName), + storage.WithAppNamespace(req.AppNamespace), + storage.WithDebug(req.Debug), + } + if currentUser != "" { + search = append(search, storage.WithUser(currentUser)) + } + conversation, err = cs.Storage().FindExistingConversation(req.ConversationID, search...) + if err != nil { + return nil, err + } + for _, v := range conversation.Messages { + _ = history.AddUserMessage(ctx, v.Query) + _ = history.AddAIMessage(ctx, v.Answer) + } + } else { + conversation = &storage.Conversation{ + ID: req.ConversationID, + AppName: req.APPName, + AppNamespace: req.AppNamespace, + StartedAt: time.Now(), + UpdatedAt: time.Now(), + Messages: make([]storage.Message, 0), + User: currentUser, + Debug: req.Debug, + } + } + conversation.Messages = append(conversation.Messages, storage.Message{ + ID: messageID, + Query: req.Query, + Answer: "", + }) + ctx = base.SetAppNamespace(ctx, req.AppNamespace) + appRun, err := appruntime.NewAppOrGetFromCache(ctx, c, app) + if err != nil { + return nil, err + } + klog.FromContext(ctx).Info("begin to run application", "appName", req.APPName, "appNamespace", req.AppNamespace) + out, err := appRun.Run(ctx, c, respStream, appruntime.Input{Question: req.Query, NeedStream: req.ResponseMode.IsStreaming(), History: history}) + if err != nil { + return nil, err + } + + conversation.UpdatedAt = time.Now() + conversation.Messages[len(conversation.Messages)-1].Answer = out.Answer + conversation.Messages[len(conversation.Messages)-1].References = out.References + if err := cs.Storage().UpdateConversation(conversation); err != nil { + return nil, err + } + return &ChatRespBody{ + ConversationID: conversation.ID, + MessageID: messageID, + Message: out.Answer, + CreatedAt: time.Now(), + References: out.References, + }, nil +} + +func (cs *ChatServer) ListConversations(ctx context.Context, req APPMetadata) ([]storage.Conversation, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + return cs.Storage().ListConversations(storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithUser(currentUser), storage.WithUser(currentUser)) +} + +func (cs *ChatServer) DeleteConversation(ctx context.Context, conversationID string) error { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + return cs.Storage().Delete(storage.WithConversationID(conversationID), storage.WithUser(currentUser)) +} + +func (cs *ChatServer) ListMessages(ctx context.Context, req ConversationReqBody) (storage.Conversation, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + c, err := cs.Storage().FindExistingConversation(req.ConversationID, storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithAppNamespace(req.AppNamespace), storage.WithUser(currentUser)) + if err != nil { + return storage.Conversation{}, err + } + if c != nil { + return *c, nil + } + return storage.Conversation{}, errors.New("conversation is not found") +} + +func (cs *ChatServer) GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) { + currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) + m, err := cs.Storage().FindExistingMessage(req.ConversationID, req.MessageID, storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithAppNamespace(req.AppNamespace), storage.WithUser(currentUser)) + if err != nil { + return nil, err + } + if m != nil && m.References != nil { + return m.References, nil + } + return nil, errors.New("conversation or message is not found") +} + +// todo Reuse the flow without having to rebuild req same, not finish, Flow doesn't start with/contain nodes that depend on incomingInput.question diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/rest_type.go similarity index 70% rename from apiserver/pkg/chat/chat_type.go rename to apiserver/pkg/chat/rest_type.go index 348c1ee73..8c1fa202d 100644 --- a/apiserver/pkg/chat/chat_type.go +++ b/apiserver/pkg/chat/rest_type.go @@ -19,8 +19,6 @@ package chat import ( "time" - "github.com/tmc/langchaingo/memory" - "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) @@ -35,7 +33,6 @@ const ( Blocking ResponseMode = "blocking" // Streaming means the response will use Server-Sent Events Streaming ResponseMode = "streaming" - // todo isFlowValidForStream only some node(llm chain) support streaming ) type APPMetadata struct { @@ -80,25 +77,6 @@ type ChatRespBody struct { References []retriever.Reference `json:"references,omitempty"` } -type Conversation struct { - ID string `json:"id" example:"5a41f3ca-763b-41ec-91c3-4bbbb00736d0"` - AppName string `json:"app_name" example:"chat-with-llm"` - AppNamespce string `json:"app_namespace" example:"arcadia"` - StartedAt time.Time `json:"started_at" example:"2023-12-21T10:21:06.389359092+08:00"` - UpdatedAt time.Time `json:"updated_at" example:"2023-12-22T10:21:06.389359092+08:00"` - Messages []Message `json:"messages"` - History *memory.ChatMessageHistory `json:"-"` - User string `json:"-"` - Debug bool `json:"-"` -} - -type Message struct { - ID string `json:"id" example:"4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24"` - Query string `json:"query" example:"旷工最小计算单位为多少天?"` - Answer string `json:"answer" example:"旷工最小计算单位为0.5天。"` - References []retriever.Reference `json:"references,omitempty"` -} - type ErrorResp struct { Err string `json:"error" example:"conversation is not found"` } diff --git a/apiserver/pkg/chat/storage/search.go b/apiserver/pkg/chat/storage/search.go new file mode 100644 index 000000000..8ff0080be --- /dev/null +++ b/apiserver/pkg/chat/storage/search.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +type Search struct { + ConversationID *string + MessageID *string + AppName *string + AppNamespace *string + User *string + Debug *bool +} + +type SearchOption func(options *Search) + +func NewSearchOptions(conversationID *string) *Search { + return &Search{ConversationID: conversationID} +} + +func applyOptions(conversationID *string, opts ...SearchOption) *Search { + o := NewSearchOptions(conversationID) + for _, opt := range opts { + opt(o) + } + return o +} + +// WithConversationID returns a Search for setting the ConversationID. +func WithConversationID(id string) SearchOption { + if id == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.ConversationID = &id + } +} + +// WithMessageID returns a Search for setting the MessageID. +func WithMessageID(id string) SearchOption { + if id == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.MessageID = &id + } +} + +// WithAppName returns a Search for setting the AppName. +func WithAppName(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.AppName = &name + } +} + +// WithAppNamespace returns a Search for setting the AppNamespace. +func WithAppNamespace(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.AppNamespace = &name + } +} + +// WithUser returns a Search for setting the User. +func WithUser(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.User = &name + } +} + +// WithDebug returns a Search for setting the Debug. +func WithDebug(debug bool) SearchOption { + return func(o *Search) { + o.Debug = &debug + } +} diff --git a/apiserver/pkg/chat/storage/storage.go b/apiserver/pkg/chat/storage/storage.go new file mode 100644 index 000000000..8ef914084 --- /dev/null +++ b/apiserver/pkg/chat/storage/storage.go @@ -0,0 +1,96 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import ( + "errors" + "time" + + "gorm.io/gorm" + + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" +) + +var ( + ErrConversationNotFound = errors.New("conversation is not found") +) + +// Conversation represent a conversation in storage +type Conversation struct { + ID string `gorm:"column:id;primaryKey;type:uuid;comment:conversation id" json:"id" example:"5a41f3ca-763b-41ec-91c3-4bbbb00736d0"` + AppName string `gorm:"column:app_name;type:string;comment:app name" json:"app_name" example:"chat-with-llm"` + AppNamespace string `gorm:"column:app_namespace;type:string;comment:app namespace" json:"app_namespace" example:"arcadia"` + StartedAt time.Time `gorm:"column:started_at;type:time;autoCreateTime;comment:the time the conversation started at" json:"started_at" example:"2023-12-21T10:21:06.389359092+08:00"` + UpdatedAt time.Time `gorm:"column:updated_at;type:time;autoUpdateTime;comment:the time the conversation updated at" json:"updated_at" example:"2023-12-22T10:21:06.389359092+08:00"` + Messages []Message `gorm:"foreignKey:ConversationID" json:"messages"` + User string `gorm:"column:user;type:string;comment:the conversation chat user" json:"-"` + Debug bool `gorm:"column:debug;type:bool;comment:debug mode" json:"-"` + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:time;comment:the time the conversation deleted at" json:"-"` +} + +// Message represent a message in storage +type Message struct { + ID string `gorm:"column:id;primaryKey;type:uuid;comment:message id" json:"id" example:"4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24"` + Query string `gorm:"column:query;type:string;comment:user input" json:"query" example:"旷工最小计算单位为多少天?"` + Answer string `gorm:"column:answer;type:string;comment:ai response" json:"answer" example:"旷工最小计算单位为0.5天。"` + References References `gorm:"column:references;type:json;comment:references" json:"references,omitempty"` + ConversationID string `gorm:"column:conversation_id;type:uuid;comment:conversation id" json:"-"` +} + +type References []retriever.Reference + +func (Conversation) TableName() string { + return "app_chat_conversation" +} + +func (Message) TableName() string { + return "app_chat_message" +} + +type Storage interface { + ConversationStorage + MessageStorage +} + +// ConversationStorage interface +type ConversationStorage interface { + // FindExistingConversation searches for an existing conversation by ConversationID. + // + // ConversationID string, opts ...SearchOption + // *Conversation, error + FindExistingConversation(ID string, opts ...SearchOption) (*Conversation, error) + // Delete deletes a conversation with the given options. + // + // It takes variadic SearchOption parameter(s) and returns an error. + // **not** return error if the conversation is not found + Delete(opts ...SearchOption) error + // UpdateConversation updates the Conversation. + // + // It takes a pointer to a Conversation and returns an error. + UpdateConversation(*Conversation) error + // ListConversations returns a list of conversations based on the provided options. + // + // It accepts SearchOption(s) and returns a slice of Conversation and an error. + ListConversations(opts ...SearchOption) ([]Conversation, error) +} + +type MessageStorage interface { + // FindExistingMessage finds a message in the conversation. + // + // It takes ConversationID, messageID string parameters and returns *Message, error. + FindExistingMessage(ConversationID, messageID string, opts ...SearchOption) (*Message, error) +} diff --git a/apiserver/pkg/chat/storage/storage_memory.go b/apiserver/pkg/chat/storage/storage_memory.go new file mode 100644 index 000000000..df4c3ebfc --- /dev/null +++ b/apiserver/pkg/chat/storage/storage_memory.go @@ -0,0 +1,153 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import "sync" + +var _ Storage = (*MemoryStorage)(nil) + +type MemoryStorage struct { + mu sync.Mutex + conversations map[string]Conversation +} + +// NewMemoryStorage creates a new MemoryStorage instance. +// +// No parameters. +// Returns a pointer to MemoryStorage. +func NewMemoryStorage() *MemoryStorage { + return &MemoryStorage{ + conversations: make(map[string]Conversation), + } +} + +// ListConversations retrieves conversations from MemoryStorage based on the provided options. +// It takes in optional SearchOption(s) and returns a slice of Conversation and an error. +func (m *MemoryStorage) ListConversations(opts ...SearchOption) (conversations []Conversation, err error) { + searchOpt := applyOptions(nil, opts...) + m.mu.Lock() + for _, c := range m.conversations { + if searchOpt.ConversationID != nil && c.ID != *searchOpt.ConversationID { + continue + } + if searchOpt.AppName != nil && c.AppName != *searchOpt.AppName { + continue + } + if searchOpt.AppNamespace != nil && c.AppNamespace != *searchOpt.AppNamespace { + continue + } + if searchOpt.User != nil && c.User != *searchOpt.User { + continue + } + if searchOpt.Debug != nil && c.Debug != *searchOpt.Debug { + continue + } + conversations = append(conversations, c) + } + m.mu.Unlock() + return conversations, nil +} + +// UpdateConversation updates a conversation in the MemoryStorage. +// +// It takes a pointer to a Conversation as a parameter and returns an error. +func (m *MemoryStorage) UpdateConversation(conversation *Conversation) error { + m.mu.Lock() + m.conversations[conversation.ID] = *conversation + m.mu.Unlock() + return nil +} + +func (m *MemoryStorage) FindExistingMessage(conversationID string, messageID string, opts ...SearchOption) (*Message, error) { + conversation, err := m.FindExistingConversation(conversationID, opts...) + if err != nil { + return nil, err + } + for _, v := range conversation.Messages { + v := v + if v.ID == messageID { + return &v, nil + } + } + return nil, nil +} + +// Delete deletes a conversation from MemoryStorage based on the provided options. +// +// Parameter(s): opts ...SearchOption +// Return type(s): error +func (m *MemoryStorage) Delete(opts ...SearchOption) (err error) { + searchOpt := applyOptions(nil, opts...) + var c *Conversation + if searchOpt.ConversationID != nil { + con, ok := m.conversations[*searchOpt.ConversationID] + if !ok { + return nil + } + c = &con + } else { + c, err = m.FindExistingConversation("", opts...) + if err != nil { + return err + } + if c == nil { + return + } + } + if searchOpt.User != nil && c.User != *searchOpt.User { + return + } + if searchOpt.AppName != nil && c.AppName != *searchOpt.AppName { + return + } + if searchOpt.AppNamespace != nil && c.AppNamespace != *searchOpt.AppNamespace { + return + } + if searchOpt.Debug != nil && c.Debug != *searchOpt.Debug { + return + } + m.mu.Lock() + delete(m.conversations, c.ID) + m.mu.Unlock() + return nil +} + +// FindExistingConversation searches for an existing conversation in MemoryStorage. +// +// ConversationID string, opt ...SearchOption. Returns *Conversation, error. +func (m *MemoryStorage) FindExistingConversation(conversationID string, opt ...SearchOption) (*Conversation, error) { + searchOpt := applyOptions(&conversationID, opt...) + m.mu.Lock() + v, ok := m.conversations[*searchOpt.ConversationID] + m.mu.Unlock() + if !ok { + return nil, ErrConversationNotFound + } + if searchOpt.Debug != nil && v.Debug != *searchOpt.Debug { + return nil, ErrConversationNotFound + } + if searchOpt.AppName != nil && v.AppName != *searchOpt.AppName { + return nil, ErrConversationNotFound + } + if searchOpt.AppNamespace != nil && v.AppNamespace != *searchOpt.AppNamespace { + return nil, ErrConversationNotFound + } + if searchOpt.User != nil && v.User != *searchOpt.User { + return nil, ErrConversationNotFound + } + return &v, nil +} diff --git a/apiserver/pkg/chat/storage/storage_postgresql.go b/apiserver/pkg/chat/storage/storage_postgresql.go new file mode 100644 index 000000000..02483a24f --- /dev/null +++ b/apiserver/pkg/chat/storage/storage_postgresql.go @@ -0,0 +1,204 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "log" + "os" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" +) + +func (r *References) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("failed to unmarshal JSONB value:%#v", value) + } + + result := make([]retriever.Reference, 0) + err := json.Unmarshal(bytes, &result) + if err != nil { + return err + } + *r = result + return nil +} + +func (r References) Value() (driver.Value, error) { + if r == nil || len([]retriever.Reference(r)) == 0 { + return nil, nil + } + // return nil, nil + return json.Marshal(r) +} + +var _ Storage = (*PostgreSQLStorage)(nil) + +type PostgreSQLStorage struct { + db *gorm.DB +} + +func (p *PostgreSQLStorage) ListConversations(opts ...SearchOption) ([]Conversation, error) { + searchOpt := applyOptions(nil, opts...) + conversationQuery := Conversation{} + if searchOpt.ConversationID != nil { + conversationQuery.ID = *searchOpt.ConversationID + } + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + conversationQuery.Debug = false + conversationQuery.DeletedAt.Valid = false + res := make([]Conversation, 0) + tx := p.db.Preload("Messages").Find(&res, conversationQuery) + if tx.Error != nil { + return nil, tx.Error + } + return res, nil +} + +func (p *PostgreSQLStorage) UpdateConversation(conversation *Conversation) error { + tx := p.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(conversation) + if tx.Error != nil { + return tx.Error + } + return nil +} + +func NewPostgreSQLStorage(conn *pgx.Conn) (*PostgreSQLStorage, error) { + connPool := stdlib.OpenDB(*conn.Config()) + db, err := gorm.Open(postgres.New(postgres.Config{Conn: connPool}), &gorm.Config{}) + if err != nil { + return nil, err + } + if err := db.AutoMigrate(&Conversation{}, &Message{}); err != nil { + return nil, err + } + customLogger := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + db.Logger = customLogger + return &PostgreSQLStorage{ + db: db, + }, nil +} + +func (p *PostgreSQLStorage) FindExistingConversation(conversationID string, opts ...SearchOption) (*Conversation, error) { + searchOpt := applyOptions(&conversationID, opts...) + conversationQuery := Conversation{ID: conversationID} + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + conversationQuery.Debug = false + conversationQuery.DeletedAt.Valid = false + res := &Conversation{} + tx := p.db.Preload("Messages").First(res, conversationQuery) + if tx.Error != nil { + return nil, tx.Error + } + return res, nil +} + +func (p *PostgreSQLStorage) Delete(opts ...SearchOption) error { + searchOpt := applyOptions(nil, opts...) + c := &Conversation{} + if searchOpt.ConversationID != nil { + c.ID = *searchOpt.ConversationID + } + if searchOpt.User != nil { + c.User = *searchOpt.User + } + if searchOpt.AppName != nil { + c.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + c.AppNamespace = *searchOpt.AppNamespace + } + if searchOpt.Debug != nil { + c.Debug = *searchOpt.Debug + } + tx := p.db.Select("Messages").Delete(c) + if tx.Error != nil { + return tx.Error + } + return nil +} + +func (p *PostgreSQLStorage) FindExistingMessage(conversationID string, messageID string, opts ...SearchOption) (*Message, error) { + searchOpt := applyOptions(&conversationID, opts...) + conversationQuery := Conversation{ID: conversationID} + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + conversationQuery.Debug = false + conversationQuery.DeletedAt.Valid = false + conversation := &Conversation{} + message := &Message{} + tx := p.db.First(conversation, conversationQuery) + if tx.Error != nil { + return nil, tx.Error + } + association := p.db.Model(conversation).Association("Messages") + if association.Error != nil { + return nil, association.Error + } + if err := association.Find(message, Message{ID: messageID}); err != nil { + return nil, err + } + return message, nil +} diff --git a/apiserver/service/chat.go b/apiserver/service/chat.go index e1b35499d..c1d26698c 100644 --- a/apiserver/service/chat.go +++ b/apiserver/service/chat.go @@ -26,11 +26,13 @@ import ( "github.com/gin-gonic/gin" "k8s.io/apimachinery/pkg/util/uuid" + "k8s.io/client-go/dynamic" "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/config" "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/chat" + "github.com/kubeagi/arcadia/apiserver/pkg/client" "github.com/kubeagi/arcadia/apiserver/pkg/oidc" "github.com/kubeagi/arcadia/apiserver/pkg/requestid" ) @@ -40,8 +42,16 @@ const ( WaitTimeoutForChatStreaming = 5 ) +type ChatService struct { + server *chat.ChatServer +} + // @BasePath /chat +func NewChatService(cli dynamic.Interface) (*ChatService, error) { + return &ChatService{chat.NewChatServer(cli)}, nil +} + // @Summary chat with application // @Schemes // @Description chat with application @@ -54,7 +64,7 @@ const ( // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router / [post] -func chatHandler() gin.HandlerFunc { +func (cs *ChatService) ChatHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.ChatReqBody{} if err := c.ShouldBindJSON(&req); err != nil { @@ -85,7 +95,7 @@ func chatHandler() gin.HandlerFunc { } } }() - response, err = chat.AppRun(c.Request.Context(), req, respStream, messageID) + response, err = cs.server.AppRun(c.Request.Context(), req, respStream, messageID) if err != nil { c.SSEvent("error", chat.ChatRespBody{ MessageID: messageID, @@ -151,7 +161,7 @@ func chatHandler() gin.HandlerFunc { klog.FromContext(c.Request.Context()).Info("end to receive messages") } else { // handle chat blocking mode - response, err = chat.AppRun(c.Request.Context(), req, nil, messageID) + response, err = cs.server.AppRun(c.Request.Context(), req, nil, messageID) if err != nil { c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) klog.FromContext(c.Request.Context()).Error(err, "error resp") @@ -170,11 +180,11 @@ func chatHandler() gin.HandlerFunc { // @Accept json // @Produce json // @Param request body chat.APPMetadata true "query params" -// @Success 200 {object} []chat.Conversation +// @Success 200 {object} []storage.Conversation // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /conversations [post] -func listConversationHandler() gin.HandlerFunc { +func (cs *ChatService) ListConversationHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.APPMetadata{} if err := c.ShouldBindJSON(&req); err != nil { @@ -182,7 +192,7 @@ func listConversationHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.ListConversations(c, req) + resp, err := cs.server.ListConversations(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error list conversation") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -204,7 +214,7 @@ func listConversationHandler() gin.HandlerFunc { // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /conversations/:conversationID [delete] -func deleteConversationHandler() gin.HandlerFunc { +func (cs *ChatService) DeleteConversationHandler() gin.HandlerFunc { return func(c *gin.Context) { conversationID := c.Param("conversationID") if conversationID == "" { @@ -213,7 +223,7 @@ func deleteConversationHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - err := chat.DeleteConversation(c, conversationID) + err := cs.server.DeleteConversation(c, conversationID) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error delete conversation") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -231,11 +241,11 @@ func deleteConversationHandler() gin.HandlerFunc { // @Accept json // @Produce json // @Param request body chat.ConversationReqBody true "query params" -// @Success 200 {object} chat.Conversation +// @Success 200 {object} storage.Conversation // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /messages [post] -func historyHandler() gin.HandlerFunc { +func (cs *ChatService) HistoryHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.ConversationReqBody{} if err := c.ShouldBindJSON(&req); err != nil { @@ -243,7 +253,7 @@ func historyHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.ListMessages(c, req) + resp, err := cs.server.ListMessages(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error list messages") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -266,7 +276,7 @@ func historyHandler() gin.HandlerFunc { // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /messages/:messageID/references [post] -func referenceHandler() gin.HandlerFunc { +func (cs *ChatService) ReferenceHandler() gin.HandlerFunc { return func(c *gin.Context) { messageID := c.Param("messageID") if messageID == "" { @@ -283,7 +293,7 @@ func referenceHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.GetMessageReferences(c, req) + resp, err := cs.server.GetMessageReferences(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error get message references") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -295,11 +305,21 @@ func referenceHandler() gin.HandlerFunc { } func registerChat(g *gin.RouterGroup, conf config.ServerConfig) { - g.POST("", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatHandler()) // chat with bot + c, err := client.GetClient(nil) + if err != nil { + panic(err) + } + + chatService, err := NewChatService(c) + if err != nil { + panic(err) + } + + g.POST("", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ChatHandler()) // chat with bot - g.POST("/conversations", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), listConversationHandler()) // list conversations - g.DELETE("/conversations/:conversationID", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), deleteConversationHandler()) // delete conversation + g.POST("/conversations", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ListConversationHandler()) // list conversations + g.DELETE("/conversations/:conversationID", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.DeleteConversationHandler()) // delete conversation - g.POST("/messages", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), historyHandler()) // messages history - g.POST("/messages/:messageID/references", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), referenceHandler()) // messages reference + g.POST("/messages", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.HistoryHandler()) // messages history + g.POST("/messages/:messageID/references", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ReferenceHandler()) // messages reference } diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 3eca815d8..6f63fa007 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.2.21 +version: 0.2.22 appVersion: "0.1.0" keywords: diff --git a/deploy/charts/arcadia/templates/apiserver.yaml b/deploy/charts/arcadia/templates/apiserver.yaml index 3da470f3c..c44271dfb 100644 --- a/deploy/charts/arcadia/templates/apiserver.yaml +++ b/deploy/charts/arcadia/templates/apiserver.yaml @@ -35,6 +35,7 @@ spec: command: - "./apiserver" args: + - "--v={{ .Values.apiserver.loglevel }}" - "--enable-playground={{ .Values.apiserver.enableplayground }}" - "--port={{ .Values.apiserver.port }}" - "--playground-endpoint-prefix={{ .Values.apiserver.ingress.path }}" diff --git a/deploy/charts/arcadia/templates/config.yaml b/deploy/charts/arcadia/templates/config.yaml index e2919746a..de03c8802 100644 --- a/deploy/charts/arcadia/templates/config.yaml +++ b/deploy/charts/arcadia/templates/config.yaml @@ -6,6 +6,13 @@ data: kind: Datasource name: '{{ .Release.Name }}-minio' namespace: '{{ .Release.Namespace }}' +{{- if .Values.postgresql.enabled }} + relationalDatasource: + apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1 + kind: Datasource + name: {{ .Release.Name }}-postgresql + namespace: {{ .Release.Namespace }} +{{- end }} {{- if gt (len .Values.ray.clusters) 0 }} rayClusters: {{- range .Values.ray.clusters }} diff --git a/deploy/charts/arcadia/values.yaml b/deploy/charts/arcadia/values.yaml index 3ca6a2651..5c247513e 100644 --- a/deploy/charts/arcadia/values.yaml +++ b/deploy/charts/arcadia/values.yaml @@ -25,6 +25,7 @@ controller: # @section graphql and bff server # related project: https://github.com/kubeagi/arcadia/tree/main/apiserver apiserver: + loglevel: 3 image: kubeagi/arcadia:v0.1.0-20240110-0dd9a1f enableplayground: false port: 8081 diff --git a/go.mod b/go.mod index fe28fbb6a..3303431cc 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/go-logr/logr v1.2.3 github.com/gofiber/fiber/v2 v2.50.0 github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/jackc/pgx/v5 v5.4.1 + github.com/jackc/pgx/v5 v5.4.3 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.3 github.com/r3labs/sse/v2 v2.10.0 @@ -26,6 +26,8 @@ require ( github.com/tmc/langchaingo v0.1.3 github.com/valyala/fasthttp v1.50.0 github.com/vektah/gqlparser/v2 v2.5.10 + gorm.io/driver/postgres v1.5.4 + gorm.io/gorm v1.25.5 k8s.io/api v0.24.2 k8s.io/apimachinery v0.24.2 k8s.io/client-go v0.24.2 @@ -57,7 +59,9 @@ require ( github.com/huandu/xstrings v1.3.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/puddle/v2 v2.2.0 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/minio/md5-simd v1.1.2 // indirect diff --git a/go.sum b/go.sum index 382b1d5db..98345b58c 100644 --- a/go.sum +++ b/go.sum @@ -432,13 +432,15 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.1 h1:oKfB/FhuVtit1bBM3zNRRsZ925ZkMN3HXL+LgLUM9lE= -github.com/jackc/pgx/v5 v5.4.1/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= -github.com/jackc/puddle/v2 v2.2.0 h1:RdcDk92EJBuBS55nQMMYFXTxwstHug4jkhT5pq8VxPk= -github.com/jackc/puddle/v2 v2.2.0/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= @@ -1282,6 +1284,10 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/config/config.go b/pkg/config/config.go index 2b7c48572..8e31d216e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -48,18 +48,9 @@ var ( ErrNoConfigRayClusters = fmt.Errorf("config RayClusters in comfigmap is not found") ) -func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { - config, err := GetConfig(ctx, c, cli) - if err != nil { - return nil, err - } - name := config.SystemDatasource.Name - var namespace string - if config.SystemDatasource.Namespace != nil { - namespace = *config.SystemDatasource.Namespace - } else { - namespace = utils.GetCurrentNamespace() - } +func getDatasource(ctx context.Context, ref arcadiav1alpha1.TypedObjectReference, c client.Client, cli dynamic.Interface) (ds *arcadiav1alpha1.Datasource, err error) { + name := ref.Name + namespace := ref.GetNamespace(utils.GetCurrentNamespace()) source := &arcadiav1alpha1.Datasource{} if c != nil { if err = c.Get(ctx, client.ObjectKey{Name: name, Namespace: namespace}, source); err != nil { @@ -79,6 +70,22 @@ func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Inter return source, err } +func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { + config, err := GetConfig(ctx, c, cli) + if err != nil { + return nil, err + } + return getDatasource(ctx, config.SystemDatasource, c, cli) +} + +func GetRelationalDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { + config, err := GetConfig(ctx, c, cli) + if err != nil { + return nil, err + } + return getDatasource(ctx, config.RelationalDatasource, c, cli) +} + func GetGateway(ctx context.Context, c client.Client, cli dynamic.Interface) (*Gateway, error) { config, err := GetConfig(ctx, c, cli) if err != nil { diff --git a/pkg/config/config_type.go b/pkg/config/config_type.go index 71d6d85d8..fcdd0a342 100644 --- a/pkg/config/config_type.go +++ b/pkg/config/config_type.go @@ -25,6 +25,9 @@ type Config struct { // SystemDatasource specifies the built-in datasource for Arcadia to host data files and model files SystemDatasource arcadiav1alpha1.TypedObjectReference `json:"systemDatasource,omitempty"` + // RelationalDatasource specifies the built-in datasource(common:postgres) for Arcadia to host relational data + RelationalDatasource arcadiav1alpha1.TypedObjectReference `json:"relationalDatasource,omitempty"` + // Gateway to access LLM api services Gateway *Gateway `json:"gateway,omitempty"` diff --git a/pkg/datasource/postgresql.go b/pkg/datasource/postgresql.go index 6486ba0ff..f264f6ca7 100644 --- a/pkg/datasource/postgresql.go +++ b/pkg/datasource/postgresql.go @@ -31,7 +31,7 @@ import ( var ( _ Datasource = (*PostgreSQL)(nil) - locker sync.Mutex + pgEnvMutex sync.Mutex poolsMutex sync.Mutex pools = make(map[string]*PostgreSQL) ) @@ -88,8 +88,8 @@ func newPostgreSQL(ctx context.Context, c client.Client, dc dynamic.Interface, c pgPassFile = string(data[v1alpha1.PGPASSFILE]) pgSSLPassword = string(data[v1alpha1.PGSSLPASSWORD]) } - locker.Lock() - defer locker.Unlock() + pgEnvMutex.Lock() + defer pgEnvMutex.Unlock() if pgUser != "" { if err := os.Setenv("PGUSER", pgUser); err != nil { return nil, err diff --git a/tests/example-test.sh b/tests/example-test.sh index 0a3b7b197..33f6a3ac6 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -377,9 +377,9 @@ info "8.2.2 QA app using knowledgebase base on pgvector" kubectl apply -f config/samples/app_retrievalqachain_knowledgebase_pgvector.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase-pgvector" sleep 3 -getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "旷工最小计算单位为多少天?" "" "true" +getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "旷工最小计算单位为多少天?" "" "true" info "8.2.2.2 When no related doc is found, return retriever.spec.docNullReturn info" -getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "飞天的主演是谁?" "" "false" +getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" "" "false" expected=$(kubectl get knowledgebaseretrievers -n arcadia base-chat-with-knowledgebase -o json | jq -r .spec.docNullReturn) if [[ $ai_data != $expected ]]; then echo "when no related doc is found, return retriever.spec.docNullReturn info should be:"$expected ", but resp:"$resp @@ -390,6 +390,7 @@ info "8.3 conversation chat app" kubectl apply -f config/samples/app_llmchain_chat_with_bot.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot" sleep 3 +getRespInAppChat "base-chat-with-bot" "arcadia" "Hi I am Bob" "" "false" getRespInAppChat "base-chat-with-bot" "arcadia" "Hi I am Jim" "" "false" getRespInAppChat "base-chat-with-bot" "arcadia" "What is my name?" ${resp_conversation_id} "false" if [[ $resp != *"Jim"* ]]; then @@ -397,10 +398,31 @@ if [[ $resp != *"Jim"* ]]; then exit 1 fi -info "8.4 check conversation list and message history" -curl -XPOST http://127.0.0.1:8081/chat/conversations --data '{"app_name": "base-chat-with-bot", "app_namespace": "arcadia"}' -data=$(jq -n --arg conversationID "$resp_conversation_id" '{"conversation_id":$conversationID, "app_name": "base-chat-with-bot", "app_namespace": "arcadia"}') -curl -XPOST http://127.0.0.1:8081/chat/messages --data "$data" +info "8.4 check other chat rest api" +info "8.4.1 conversation list" +resp=$(curl -s -XPOST http://127.0.0.1:8081/chat/conversations --data '{"app_name": "base-chat-with-bot", "app_namespace": "arcadia"}') +echo $resp | jq . +delete_conversation_id=$(echo $resp | jq -r '.[0].id') +info "8.4.2 message list" +data=$(jq -n --arg conversationID "$delete_conversation_id" '{"conversation_id":$conversationID, "app_name": "base-chat-with-bot", "app_namespace": "arcadia"}') +resp=$(curl -s -XPOST http://127.0.0.1:8081/chat/messages --data "$data") +echo $resp | jq . +info "8.4.3 message references" +resp=$(curl -s -XPOST http://127.0.0.1:8081/chat/conversations --data '{"app_name": "base-chat-with-knowledgebase-pgvector", "app_namespace": "arcadia"}') +message_id=$(echo $resp | jq -r '.[0].messages[0].id') +conversation_id=$(echo $resp | jq -r '.[0].id') +data=$(jq -n --arg conversationID "$conversation_id" '{"conversation_id":$conversationID, "app_name": "base-chat-with-knowledgebase-pgvector", "app_namespace": "arcadia"}') +resp=$(curl -s -XPOST http://127.0.0.1:8081/chat/messages/$message_id/references --data "$data") +echo $resp | jq . +info "8.4.4 delete conversation" +resp=$(curl -s -XDELETE http://127.0.0.1:8081/chat/conversations/$delete_conversation_id) +echo $resp | jq . +resp=$(curl -s -XPOST http://127.0.0.1:8081/chat/conversations --data '{"app_name": "base-chat-with-bot", "app_namespace": "arcadia"}') +if [[ $resp == *"$delete_conversation_id"* ]]; then + echo "delete conversation failed" + exit 1 +fi + # There is uncertainty in the AI replies, most of the time, it will pass the test, a small percentage of the time, the AI will call names in each reply, causing the test to fail, therefore, temporarily disable the following tests #getRespInAppChat "base-chat-with-bot" "arcadia" "What is your model?" ${resp_conversation_id} "false" #getRespInAppChat "base-chat-with-bot" "arcadia" "Does your model based on gpt-3.5?" ${resp_conversation_id} "false"