diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 84c965d45..49acc1a91 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -97,7 +97,7 @@ const docTemplate = `{ "schema": { "type": "array", "items": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } } }, @@ -188,7 +188,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } }, "400": { @@ -1012,37 +1012,6 @@ const docTemplate = `{ } } }, - "chat.Conversation": { - "type": "object", - "properties": { - "app_name": { - "type": "string", - "example": "chat-with-llm" - }, - "app_namespace": { - "type": "string", - "example": "arcadia" - }, - "id": { - "type": "string", - "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.Message" - } - }, - "started_at": { - "type": "string", - "example": "2023-12-21T10:21:06.389359092+08:00" - }, - "updated_at": { - "type": "string", - "example": "2023-12-22T10:21:06.389359092+08:00" - } - } - }, "chat.ConversationReqBody": { "type": "object", "required": [ @@ -1076,29 +1045,6 @@ const docTemplate = `{ } } }, - "chat.Message": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "example": "旷工最小计算单位为0.5天。" - }, - "id": { - "type": "string", - "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" - }, - "query": { - "type": "string", - "example": "旷工最小计算单位为多少天?" - }, - "references": { - "type": "array", - "items": { - "$ref": "#/definitions/retriever.Reference" - } - } - } - }, "chat.MessageReqBody": { "type": "object", "required": [ @@ -1341,6 +1287,60 @@ const docTemplate = `{ "type": "string" } } + }, + "storage.Conversation": { + "type": "object", + "properties": { + "app_name": { + "type": "string", + "example": "chat-with-llm" + }, + "app_namespace": { + "type": "string", + "example": "arcadia" + }, + "id": { + "type": "string", + "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/storage.Message" + } + }, + "started_at": { + "type": "string", + "example": "2023-12-21T10:21:06.389359092+08:00" + }, + "updated_at": { + "type": "string", + "example": "2023-12-22T10:21:06.389359092+08:00" + } + } + }, + "storage.Message": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "example": "旷工最小计算单位为0.5天。" + }, + "id": { + "type": "string", + "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" + }, + "query": { + "type": "string", + "example": "旷工最小计算单位为多少天?" + }, + "references": { + "type": "array", + "items": { + "$ref": "#/definitions/retriever.Reference" + } + } + } } } }` diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index 83ee95fc0..f411f05fd 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -86,7 +86,7 @@ "schema": { "type": "array", "items": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } } }, @@ -177,7 +177,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/chat.Conversation" + "$ref": "#/definitions/storage.Conversation" } }, "400": { @@ -1001,37 +1001,6 @@ } } }, - "chat.Conversation": { - "type": "object", - "properties": { - "app_name": { - "type": "string", - "example": "chat-with-llm" - }, - "app_namespace": { - "type": "string", - "example": "arcadia" - }, - "id": { - "type": "string", - "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/definitions/chat.Message" - } - }, - "started_at": { - "type": "string", - "example": "2023-12-21T10:21:06.389359092+08:00" - }, - "updated_at": { - "type": "string", - "example": "2023-12-22T10:21:06.389359092+08:00" - } - } - }, "chat.ConversationReqBody": { "type": "object", "required": [ @@ -1065,29 +1034,6 @@ } } }, - "chat.Message": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "example": "旷工最小计算单位为0.5天。" - }, - "id": { - "type": "string", - "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" - }, - "query": { - "type": "string", - "example": "旷工最小计算单位为多少天?" - }, - "references": { - "type": "array", - "items": { - "$ref": "#/definitions/retriever.Reference" - } - } - } - }, "chat.MessageReqBody": { "type": "object", "required": [ @@ -1330,6 +1276,60 @@ "type": "string" } } + }, + "storage.Conversation": { + "type": "object", + "properties": { + "app_name": { + "type": "string", + "example": "chat-with-llm" + }, + "app_namespace": { + "type": "string", + "example": "arcadia" + }, + "id": { + "type": "string", + "example": "5a41f3ca-763b-41ec-91c3-4bbbb00736d0" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/definitions/storage.Message" + } + }, + "started_at": { + "type": "string", + "example": "2023-12-21T10:21:06.389359092+08:00" + }, + "updated_at": { + "type": "string", + "example": "2023-12-22T10:21:06.389359092+08:00" + } + } + }, + "storage.Message": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "example": "旷工最小计算单位为0.5天。" + }, + "id": { + "type": "string", + "example": "4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24" + }, + "query": { + "type": "string", + "example": "旷工最小计算单位为多少天?" + }, + "references": { + "type": "array", + "items": { + "$ref": "#/definitions/retriever.Reference" + } + } + } } } } \ No newline at end of file diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index de9d73469..2cd9d1f06 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -67,28 +67,6 @@ definitions: $ref: '#/definitions/retriever.Reference' type: array type: object - chat.Conversation: - properties: - app_name: - example: chat-with-llm - type: string - app_namespace: - example: arcadia - type: string - id: - example: 5a41f3ca-763b-41ec-91c3-4bbbb00736d0 - type: string - messages: - items: - $ref: '#/definitions/chat.Message' - type: array - started_at: - example: "2023-12-21T10:21:06.389359092+08:00" - type: string - updated_at: - example: "2023-12-22T10:21:06.389359092+08:00" - type: string - type: object chat.ConversationReqBody: properties: app_name: @@ -113,22 +91,6 @@ definitions: example: conversation is not found type: string type: object - chat.Message: - properties: - answer: - example: 旷工最小计算单位为0.5天。 - type: string - id: - example: 4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24 - type: string - query: - example: 旷工最小计算单位为多少天? - type: string - references: - items: - $ref: '#/definitions/retriever.Reference' - type: array - type: object chat.MessageReqBody: properties: app_name: @@ -298,6 +260,44 @@ definitions: uploadID: type: string type: object + storage.Conversation: + properties: + app_name: + example: chat-with-llm + type: string + app_namespace: + example: arcadia + type: string + id: + example: 5a41f3ca-763b-41ec-91c3-4bbbb00736d0 + type: string + messages: + items: + $ref: '#/definitions/storage.Message' + type: array + started_at: + example: "2023-12-21T10:21:06.389359092+08:00" + type: string + updated_at: + example: "2023-12-22T10:21:06.389359092+08:00" + type: string + type: object + storage.Message: + properties: + answer: + example: 旷工最小计算单位为0.5天。 + type: string + id: + example: 4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24 + type: string + query: + example: 旷工最小计算单位为多少天? + type: string + references: + items: + $ref: '#/definitions/retriever.Reference' + type: array + type: object info: contact: {} paths: @@ -355,7 +355,7 @@ paths: description: OK schema: items: - $ref: '#/definitions/chat.Conversation' + $ref: '#/definitions/storage.Conversation' type: array "400": description: Bad Request @@ -415,7 +415,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/chat.Conversation' + $ref: '#/definitions/storage.Conversation' "400": description: Bad Request schema: diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat_server.go similarity index 55% rename from apiserver/pkg/chat/chat.go rename to apiserver/pkg/chat/chat_server.go index 1afcbc8d2..48e2b5783 100644 --- a/apiserver/pkg/chat/chat.go +++ b/apiserver/pkg/chat/chat_server.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "github.com/tmc/langchaingo/memory" @@ -31,18 +30,24 @@ import ( "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/apiserver/pkg/auth" + "github.com/kubeagi/arcadia/apiserver/pkg/chat/storage" "github.com/kubeagi/arcadia/apiserver/pkg/client" "github.com/kubeagi/arcadia/pkg/appruntime" "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) -var ( - mu sync.Mutex - Conversations = map[string]Conversation{} -) +type ChatServer struct { + storage storage.Storage +} + +func NewChatServer(storage storage.Storage) *ChatServer { + return &ChatServer{ + storage: storage, + } +} -func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messageID string) (*ChatRespBody, error) { +func (cs *ChatServer) AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messageID string) (*ChatRespBody, error) { token := auth.ForOIDCToken(ctx) c, err := client.GetClient(token) if err != nil { @@ -61,38 +66,35 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messag if !app.Status.IsReady() { return nil, errors.New("application is not ready") } - var conversation Conversation + var conversation *storage.Conversation currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) if !req.NewChat { - var ok bool - conversation, ok = Conversations[req.ConversationID] - if !ok { - return nil, errors.New("conversation is not found") + search := []storage.SearchOption{ + storage.WithAppName(req.APPName), + storage.WithAppNamespace(req.AppNamespace), + storage.WithDebug(req.Debug), } - if currentUser != "" && currentUser != conversation.User { - return nil, errors.New("conversation id not match with user") + if currentUser != "" { + search = append(search, storage.WithUser(currentUser)) } - if conversation.AppName != req.APPName || conversation.AppNamespce != req.AppNamespace { - return nil, errors.New("conversation id not match with app info") - } - if conversation.Debug != req.Debug { - return nil, errors.New("conversation id not match with debug") + conversation, err = cs.storage.FindExistingConversation(req.ConversationID, search...) + if err != nil { + return nil, err } } else { - conversation = Conversation{ - ID: req.ConversationID, - AppName: req.APPName, - AppNamespce: req.AppNamespace, - StartedAt: time.Now(), - UpdatedAt: time.Now(), - Messages: make([]Message, 0), - History: memory.NewChatMessageHistory(), - User: currentUser, - Debug: req.Debug, + conversation = &storage.Conversation{ + ID: req.ConversationID, + AppName: req.APPName, + AppNamespace: req.AppNamespace, + StartedAt: time.Now(), + UpdatedAt: time.Now(), + Messages: make([]storage.Message, 0), + History: memory.NewChatMessageHistory(), + User: currentUser, + Debug: req.Debug, } } - - conversation.Messages = append(conversation.Messages, Message{ + conversation.Messages = append(conversation.Messages, storage.Message{ ID: messageID, Query: req.Query, Answer: "", @@ -111,9 +113,9 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messag conversation.UpdatedAt = time.Now() conversation.Messages[len(conversation.Messages)-1].Answer = out.Answer conversation.Messages[len(conversation.Messages)-1].References = out.References - mu.Lock() - Conversations[conversation.ID] = conversation - mu.Unlock() + if err := cs.storage.UpdateConversation(conversation); err != nil { + return nil, err + } return &ChatRespBody{ ConversationID: conversation.ID, MessageID: messageID, @@ -123,56 +125,36 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string, messag }, nil } -func ListConversations(ctx context.Context, req APPMetadata) ([]Conversation, error) { - conversations := make([]Conversation, 0) +func (cs *ChatServer) ListConversations(ctx context.Context, req APPMetadata) ([]storage.Conversation, error) { currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - for _, c := range Conversations { - if !c.Debug && c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && (currentUser == "" || currentUser == c.User) { - conversations = append(conversations, c) - } - } - mu.Unlock() - return conversations, nil + return cs.storage.ListConversations(storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithUser(currentUser), storage.WithUser(currentUser)) } -func DeleteConversation(ctx context.Context, conversationID string) error { +func (cs *ChatServer) DeleteConversation(ctx context.Context, conversationID string) error { currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - c, ok := Conversations[conversationID] - if ok && (currentUser == "" || currentUser == c.User) { - delete(Conversations, c.ID) - return nil - } else { - return errors.New("conversation is not found") - } + return cs.storage.Delete(storage.WithConversationID(conversationID), storage.WithUser(currentUser)) } -func ListMessages(ctx context.Context, req ConversationReqBody) (Conversation, error) { +func (cs *ChatServer) ListMessages(ctx context.Context, req ConversationReqBody) (storage.Conversation, error) { currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - for _, c := range Conversations { - if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && req.ConversationID == c.ID && (currentUser == "" || currentUser == c.User) { - return c, nil - } + c, err := cs.storage.FindExistingConversation(req.ConversationID, storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithAppNamespace(req.AppNamespace), storage.WithUser(currentUser)) + if err != nil { + return storage.Conversation{}, err } - return Conversation{}, errors.New("conversation is not found") + if c != nil { + return *c, nil + } + return storage.Conversation{}, errors.New("conversation is not found") } -func GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) { +func (cs *ChatServer) GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) { currentUser, _ := ctx.Value(auth.UserNameContextKey).(string) - mu.Lock() - defer mu.Unlock() - for _, c := range Conversations { - if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && c.ID == req.ConversationID && (currentUser == "" || currentUser == c.User) { - for _, m := range c.Messages { - if m.ID == req.MessageID { - return m.References, nil - } - } - } + m, err := cs.storage.FindExistingMessage(req.ConversationID, req.MessageID, storage.WithAppNamespace(req.AppNamespace), storage.WithAppName(req.APPName), storage.WithAppNamespace(req.AppNamespace), storage.WithUser(currentUser)) + if err != nil { + return nil, err + } + if m != nil { + return m.References, nil } return nil, errors.New("conversation or message is not found") } diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/rest_type.go similarity index 70% rename from apiserver/pkg/chat/chat_type.go rename to apiserver/pkg/chat/rest_type.go index 348c1ee73..8c1fa202d 100644 --- a/apiserver/pkg/chat/chat_type.go +++ b/apiserver/pkg/chat/rest_type.go @@ -19,8 +19,6 @@ package chat import ( "time" - "github.com/tmc/langchaingo/memory" - "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) @@ -35,7 +33,6 @@ const ( Blocking ResponseMode = "blocking" // Streaming means the response will use Server-Sent Events Streaming ResponseMode = "streaming" - // todo isFlowValidForStream only some node(llm chain) support streaming ) type APPMetadata struct { @@ -80,25 +77,6 @@ type ChatRespBody struct { References []retriever.Reference `json:"references,omitempty"` } -type Conversation struct { - ID string `json:"id" example:"5a41f3ca-763b-41ec-91c3-4bbbb00736d0"` - AppName string `json:"app_name" example:"chat-with-llm"` - AppNamespce string `json:"app_namespace" example:"arcadia"` - StartedAt time.Time `json:"started_at" example:"2023-12-21T10:21:06.389359092+08:00"` - UpdatedAt time.Time `json:"updated_at" example:"2023-12-22T10:21:06.389359092+08:00"` - Messages []Message `json:"messages"` - History *memory.ChatMessageHistory `json:"-"` - User string `json:"-"` - Debug bool `json:"-"` -} - -type Message struct { - ID string `json:"id" example:"4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24"` - Query string `json:"query" example:"旷工最小计算单位为多少天?"` - Answer string `json:"answer" example:"旷工最小计算单位为0.5天。"` - References []retriever.Reference `json:"references,omitempty"` -} - type ErrorResp struct { Err string `json:"error" example:"conversation is not found"` } diff --git a/apiserver/pkg/chat/storage/search.go b/apiserver/pkg/chat/storage/search.go new file mode 100644 index 000000000..8ff0080be --- /dev/null +++ b/apiserver/pkg/chat/storage/search.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +type Search struct { + ConversationID *string + MessageID *string + AppName *string + AppNamespace *string + User *string + Debug *bool +} + +type SearchOption func(options *Search) + +func NewSearchOptions(conversationID *string) *Search { + return &Search{ConversationID: conversationID} +} + +func applyOptions(conversationID *string, opts ...SearchOption) *Search { + o := NewSearchOptions(conversationID) + for _, opt := range opts { + opt(o) + } + return o +} + +// WithConversationID returns a Search for setting the ConversationID. +func WithConversationID(id string) SearchOption { + if id == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.ConversationID = &id + } +} + +// WithMessageID returns a Search for setting the MessageID. +func WithMessageID(id string) SearchOption { + if id == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.MessageID = &id + } +} + +// WithAppName returns a Search for setting the AppName. +func WithAppName(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.AppName = &name + } +} + +// WithAppNamespace returns a Search for setting the AppNamespace. +func WithAppNamespace(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.AppNamespace = &name + } +} + +// WithUser returns a Search for setting the User. +func WithUser(name string) SearchOption { + if name == "" { + return func(o *Search) {} + } + return func(o *Search) { + o.User = &name + } +} + +// WithDebug returns a Search for setting the Debug. +func WithDebug(debug bool) SearchOption { + return func(o *Search) { + o.Debug = &debug + } +} diff --git a/apiserver/pkg/chat/storage/storage.go b/apiserver/pkg/chat/storage/storage.go new file mode 100644 index 000000000..e1d07a7da --- /dev/null +++ b/apiserver/pkg/chat/storage/storage.go @@ -0,0 +1,98 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import ( + "errors" + "time" + + "github.com/tmc/langchaingo/memory" + "gorm.io/gorm" + + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" +) + +var ( + ErrConversationNotFound = errors.New("conversation is not found") +) + +// Conversation represent a conversation in storage +type Conversation struct { + ID string `gorm:"column:id;primaryKey;type:uuid;comment:conversation id" json:"id" example:"5a41f3ca-763b-41ec-91c3-4bbbb00736d0"` + AppName string `gorm:"column:app_name;type:string;comment:app name" json:"app_name" example:"chat-with-llm"` + AppNamespace string `gorm:"column:app_namespace;type:string;comment:app namespace" json:"app_namespace" example:"arcadia"` + StartedAt time.Time `gorm:"column:started_at;type:time;autoCreateTime;comment:the time the conversation started at" json:"started_at" example:"2023-12-21T10:21:06.389359092+08:00"` + UpdatedAt time.Time `gorm:"column:updated_at;type:time;autoUpdateTime;comment:the time the conversation updated at" json:"updated_at" example:"2023-12-22T10:21:06.389359092+08:00"` + Messages []Message `gorm:"foreignKey:ConversationID" json:"messages"` + History *memory.ChatMessageHistory `gorm:"-" json:"-"` + User string `gorm:"column:user;type:string;comment:the conversation chat user" json:"-"` + Debug bool `gorm:"column:debug;type:bool;comment:debug mode" json:"-"` + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:time;comment:the time the conversation deleted at" json:"-"` +} + +// Message represent a message in storage +type Message struct { + ID string `gorm:"column:id;primaryKey;type:uuid;comment:message id" json:"id" example:"4f3546dd-5404-4bf8-a3bc-4fa3f9a7ba24"` + Query string `gorm:"column:query;type:string;comment:user input" json:"query" example:"旷工最小计算单位为多少天?"` + Answer string `gorm:"column:answer;type:string;comment:ai response" json:"answer" example:"旷工最小计算单位为0.5天。"` + References References `gorm:"column:references;type:json;comment:references" json:"references,omitempty"` + ConversationID string `gorm:"column:conversation_id;type:uuid;comment:conversation id" json:"-"` +} + +type References []retriever.Reference + +func (Conversation) TableName() string { + return "app_chat_conversation" +} + +func (Message) TableName() string { + return "app_chat_message" +} + +type Storage interface { + ConversationStorage + MessageStorage +} + +// ConversationStorage interface +type ConversationStorage interface { + // FindExistingConversation searches for an existing conversation by ConversationID. + // + // ConversationID string, opts ...SearchOption + // *Conversation, error + FindExistingConversation(ID string, opts ...SearchOption) (*Conversation, error) + // Delete deletes a conversation with the given options. + // + // It takes variadic SearchOption parameter(s) and returns an error. + // **not** return error if the conversation is not found + Delete(opts ...SearchOption) error + // UpdateConversation updates the Conversation. + // + // It takes a pointer to a Conversation and returns an error. + UpdateConversation(*Conversation) error + // ListConversations returns a list of conversations based on the provided options. + // + // It accepts SearchOption(s) and returns a slice of Conversation and an error. + ListConversations(opts ...SearchOption) ([]Conversation, error) +} + +type MessageStorage interface { + // FindExistingMessage finds a message in the conversation. + // + // It takes ConversationID, messageID string parameters and returns *Message, error. + FindExistingMessage(ConversationID, messageID string, opts ...SearchOption) (*Message, error) +} diff --git a/apiserver/pkg/chat/storage/storage_memory.go b/apiserver/pkg/chat/storage/storage_memory.go new file mode 100644 index 000000000..df4c3ebfc --- /dev/null +++ b/apiserver/pkg/chat/storage/storage_memory.go @@ -0,0 +1,153 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import "sync" + +var _ Storage = (*MemoryStorage)(nil) + +type MemoryStorage struct { + mu sync.Mutex + conversations map[string]Conversation +} + +// NewMemoryStorage creates a new MemoryStorage instance. +// +// No parameters. +// Returns a pointer to MemoryStorage. +func NewMemoryStorage() *MemoryStorage { + return &MemoryStorage{ + conversations: make(map[string]Conversation), + } +} + +// ListConversations retrieves conversations from MemoryStorage based on the provided options. +// It takes in optional SearchOption(s) and returns a slice of Conversation and an error. +func (m *MemoryStorage) ListConversations(opts ...SearchOption) (conversations []Conversation, err error) { + searchOpt := applyOptions(nil, opts...) + m.mu.Lock() + for _, c := range m.conversations { + if searchOpt.ConversationID != nil && c.ID != *searchOpt.ConversationID { + continue + } + if searchOpt.AppName != nil && c.AppName != *searchOpt.AppName { + continue + } + if searchOpt.AppNamespace != nil && c.AppNamespace != *searchOpt.AppNamespace { + continue + } + if searchOpt.User != nil && c.User != *searchOpt.User { + continue + } + if searchOpt.Debug != nil && c.Debug != *searchOpt.Debug { + continue + } + conversations = append(conversations, c) + } + m.mu.Unlock() + return conversations, nil +} + +// UpdateConversation updates a conversation in the MemoryStorage. +// +// It takes a pointer to a Conversation as a parameter and returns an error. +func (m *MemoryStorage) UpdateConversation(conversation *Conversation) error { + m.mu.Lock() + m.conversations[conversation.ID] = *conversation + m.mu.Unlock() + return nil +} + +func (m *MemoryStorage) FindExistingMessage(conversationID string, messageID string, opts ...SearchOption) (*Message, error) { + conversation, err := m.FindExistingConversation(conversationID, opts...) + if err != nil { + return nil, err + } + for _, v := range conversation.Messages { + v := v + if v.ID == messageID { + return &v, nil + } + } + return nil, nil +} + +// Delete deletes a conversation from MemoryStorage based on the provided options. +// +// Parameter(s): opts ...SearchOption +// Return type(s): error +func (m *MemoryStorage) Delete(opts ...SearchOption) (err error) { + searchOpt := applyOptions(nil, opts...) + var c *Conversation + if searchOpt.ConversationID != nil { + con, ok := m.conversations[*searchOpt.ConversationID] + if !ok { + return nil + } + c = &con + } else { + c, err = m.FindExistingConversation("", opts...) + if err != nil { + return err + } + if c == nil { + return + } + } + if searchOpt.User != nil && c.User != *searchOpt.User { + return + } + if searchOpt.AppName != nil && c.AppName != *searchOpt.AppName { + return + } + if searchOpt.AppNamespace != nil && c.AppNamespace != *searchOpt.AppNamespace { + return + } + if searchOpt.Debug != nil && c.Debug != *searchOpt.Debug { + return + } + m.mu.Lock() + delete(m.conversations, c.ID) + m.mu.Unlock() + return nil +} + +// FindExistingConversation searches for an existing conversation in MemoryStorage. +// +// ConversationID string, opt ...SearchOption. Returns *Conversation, error. +func (m *MemoryStorage) FindExistingConversation(conversationID string, opt ...SearchOption) (*Conversation, error) { + searchOpt := applyOptions(&conversationID, opt...) + m.mu.Lock() + v, ok := m.conversations[*searchOpt.ConversationID] + m.mu.Unlock() + if !ok { + return nil, ErrConversationNotFound + } + if searchOpt.Debug != nil && v.Debug != *searchOpt.Debug { + return nil, ErrConversationNotFound + } + if searchOpt.AppName != nil && v.AppName != *searchOpt.AppName { + return nil, ErrConversationNotFound + } + if searchOpt.AppNamespace != nil && v.AppNamespace != *searchOpt.AppNamespace { + return nil, ErrConversationNotFound + } + if searchOpt.User != nil && v.User != *searchOpt.User { + return nil, ErrConversationNotFound + } + return &v, nil +} diff --git a/apiserver/pkg/chat/storage/storage_postgresql.go b/apiserver/pkg/chat/storage/storage_postgresql.go new file mode 100644 index 000000000..97ebe6ac0 --- /dev/null +++ b/apiserver/pkg/chat/storage/storage_postgresql.go @@ -0,0 +1,191 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package storage + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "log" + "os" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" +) + +func (r *References) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("failed to unmarshal JSONB value:%#v", value) + } + + result := make([]retriever.Reference, 0) + err := json.Unmarshal(bytes, &result) + if err != nil { + return err + } + *r = result + return nil +} + +func (r References) Value() (driver.Value, error) { + if r == nil || len([]retriever.Reference(r)) == 0 { + return nil, nil + } + // return nil, nil + return json.Marshal(r) +} + +var _ Storage = (*PostgreSQLStorage)(nil) + +type PostgreSQLStorage struct { + db *gorm.DB +} + +func (p *PostgreSQLStorage) ListConversations(opts ...SearchOption) ([]Conversation, error) { + searchOpt := applyOptions(nil, opts...) + conversationQuery := Conversation{} + if searchOpt.ConversationID != nil { + conversationQuery.ID = *searchOpt.ConversationID + } + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + res := make([]Conversation, 0) + tx := p.db.Find(res, conversationQuery).Preload("Messages") + if tx.Error != nil { + return nil, tx.Error + } + return res, nil +} + +func (p *PostgreSQLStorage) UpdateConversation(conversation *Conversation) error { + tx := p.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(conversation) + if tx.Error != nil { + return tx.Error + } + return nil +} + +func NewPostgreSQLStorage(conn *pgx.Conn) (*PostgreSQLStorage, error) { + connPool := stdlib.OpenDB(*conn.Config()) + db, err := gorm.Open(postgres.New(postgres.Config{Conn: connPool}), &gorm.Config{}) + if err != nil { + return nil, err + } + if err := db.AutoMigrate(&Conversation{}, &Message{}); err != nil { + return nil, err + } + log := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: logger.Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + db.Logger = log + return &PostgreSQLStorage{ + db: db, + }, nil +} + +func (p *PostgreSQLStorage) FindExistingConversation(conversationID string, opts ...SearchOption) (*Conversation, error) { + searchOpt := applyOptions(&conversationID, opts...) + conversationQuery := Conversation{ID: conversationID} + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + res := &Conversation{} + tx := p.db.First(res, conversationQuery).Preload("Messages") + if tx.Error != nil { + return nil, tx.Error + } + return res, nil +} + +func (p *PostgreSQLStorage) Delete(opts ...SearchOption) error { + searchOpt := applyOptions(nil, opts...) + c := Conversation{} + if searchOpt.ConversationID != nil { + c.ID = *searchOpt.ConversationID + } + if searchOpt.User != nil { + c.User = *searchOpt.User + } + if searchOpt.AppName != nil { + c.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + c.AppNamespace = *searchOpt.AppNamespace + } + if searchOpt.Debug != nil { + c.Debug = *searchOpt.Debug + } + tx := p.db.Select("Messages").Delete(c) + if tx.Error != nil { + return tx.Error + } + return nil +} + +func (p *PostgreSQLStorage) FindExistingMessage(conversationID string, messageID string, opts ...SearchOption) (*Message, error) { + searchOpt := applyOptions(&conversationID, opts...) + conversationQuery := Conversation{ID: conversationID} + if searchOpt.Debug != nil { + conversationQuery.Debug = *searchOpt.Debug + } + if searchOpt.User != nil { + conversationQuery.User = *searchOpt.User + } + if searchOpt.AppName != nil { + conversationQuery.AppName = *searchOpt.AppName + } + if searchOpt.AppNamespace != nil { + conversationQuery.AppNamespace = *searchOpt.AppNamespace + } + conversation := &Conversation{} + message := &Message{} + tx := p.db.Find(conversation, conversationQuery).First(message, Message{ID: messageID}) + if tx.Error != nil { + return nil, tx.Error + } + return message, nil +} diff --git a/apiserver/service/chat.go b/apiserver/service/chat.go index e1b35499d..f79da3c04 100644 --- a/apiserver/service/chat.go +++ b/apiserver/service/chat.go @@ -17,6 +17,7 @@ limitations under the License. package service import ( + "context" "errors" "fmt" "io" @@ -26,13 +27,18 @@ import ( "github.com/gin-gonic/gin" "k8s.io/apimachinery/pkg/util/uuid" + "k8s.io/client-go/dynamic" "k8s.io/klog/v2" "github.com/kubeagi/arcadia/apiserver/config" "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/chat" + "github.com/kubeagi/arcadia/apiserver/pkg/chat/storage" + "github.com/kubeagi/arcadia/apiserver/pkg/client" "github.com/kubeagi/arcadia/apiserver/pkg/oidc" "github.com/kubeagi/arcadia/apiserver/pkg/requestid" + pkgconfig "github.com/kubeagi/arcadia/pkg/config" + "github.com/kubeagi/arcadia/pkg/datasource" ) const ( @@ -40,8 +46,40 @@ const ( WaitTimeoutForChatStreaming = 5 ) +type ChatService struct { + server *chat.ChatServer +} + // @BasePath /chat +func NewChatService(cli dynamic.Interface) (*ChatService, error) { + ctx := context.TODO() + ds, err := pkgconfig.GetRelationalDatasource(ctx, nil, cli) + if err != nil || ds == nil { + if err != nil { + klog.Infof("get relational datasource failed: %s, use memory storage for chat\n", err.Error()) + } else if ds == nil { + klog.Infoln("no relational datasource found, use memory storage for chat") + } + return &ChatService{chat.NewChatServer(storage.NewMemoryStorage())}, nil + } + pg, err := datasource.GetPostgreSQLPool(ctx, nil, cli, ds) + if err != nil { + klog.Errorf("get postgresql pool failed : %s", err.Error()) + return nil, err + } + conn, err := pg.Pool.Acquire(ctx) + if err != nil { + return nil, err + } + db, err := storage.NewPostgreSQLStorage(conn.Conn()) + if err != nil { + return nil, err + } + klog.Infoln("use pg as chat storage.") + return &ChatService{chat.NewChatServer(db)}, nil +} + // @Summary chat with application // @Schemes // @Description chat with application @@ -54,7 +92,7 @@ const ( // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router / [post] -func chatHandler() gin.HandlerFunc { +func (cs *ChatService) ChatHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.ChatReqBody{} if err := c.ShouldBindJSON(&req); err != nil { @@ -85,7 +123,7 @@ func chatHandler() gin.HandlerFunc { } } }() - response, err = chat.AppRun(c.Request.Context(), req, respStream, messageID) + response, err = cs.server.AppRun(c.Request.Context(), req, respStream, messageID) if err != nil { c.SSEvent("error", chat.ChatRespBody{ MessageID: messageID, @@ -151,7 +189,7 @@ func chatHandler() gin.HandlerFunc { klog.FromContext(c.Request.Context()).Info("end to receive messages") } else { // handle chat blocking mode - response, err = chat.AppRun(c.Request.Context(), req, nil, messageID) + response, err = cs.server.AppRun(c.Request.Context(), req, nil, messageID) if err != nil { c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) klog.FromContext(c.Request.Context()).Error(err, "error resp") @@ -170,11 +208,11 @@ func chatHandler() gin.HandlerFunc { // @Accept json // @Produce json // @Param request body chat.APPMetadata true "query params" -// @Success 200 {object} []chat.Conversation +// @Success 200 {object} []storage.Conversation // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /conversations [post] -func listConversationHandler() gin.HandlerFunc { +func (cs *ChatService) ListConversationHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.APPMetadata{} if err := c.ShouldBindJSON(&req); err != nil { @@ -182,7 +220,7 @@ func listConversationHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.ListConversations(c, req) + resp, err := cs.server.ListConversations(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error list conversation") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -204,7 +242,7 @@ func listConversationHandler() gin.HandlerFunc { // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /conversations/:conversationID [delete] -func deleteConversationHandler() gin.HandlerFunc { +func (cs *ChatService) DeleteConversationHandler() gin.HandlerFunc { return func(c *gin.Context) { conversationID := c.Param("conversationID") if conversationID == "" { @@ -213,7 +251,7 @@ func deleteConversationHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - err := chat.DeleteConversation(c, conversationID) + err := cs.server.DeleteConversation(c, conversationID) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error delete conversation") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -231,11 +269,11 @@ func deleteConversationHandler() gin.HandlerFunc { // @Accept json // @Produce json // @Param request body chat.ConversationReqBody true "query params" -// @Success 200 {object} chat.Conversation +// @Success 200 {object} storage.Conversation // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /messages [post] -func historyHandler() gin.HandlerFunc { +func (cs *ChatService) HistoryHandler() gin.HandlerFunc { return func(c *gin.Context) { req := chat.ConversationReqBody{} if err := c.ShouldBindJSON(&req); err != nil { @@ -243,7 +281,7 @@ func historyHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.ListMessages(c, req) + resp, err := cs.server.ListMessages(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error list messages") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -266,7 +304,7 @@ func historyHandler() gin.HandlerFunc { // @Failure 400 {object} chat.ErrorResp // @Failure 500 {object} chat.ErrorResp // @Router /messages/:messageID/references [post] -func referenceHandler() gin.HandlerFunc { +func (cs *ChatService) ReferenceHandler() gin.HandlerFunc { return func(c *gin.Context) { messageID := c.Param("messageID") if messageID == "" { @@ -283,7 +321,7 @@ func referenceHandler() gin.HandlerFunc { c.JSON(http.StatusBadRequest, chat.ErrorResp{Err: err.Error()}) return } - resp, err := chat.GetMessageReferences(c, req) + resp, err := cs.server.GetMessageReferences(c, req) if err != nil { klog.FromContext(c.Request.Context()).Error(err, "error get message references") c.JSON(http.StatusInternalServerError, chat.ErrorResp{Err: err.Error()}) @@ -295,11 +333,21 @@ func referenceHandler() gin.HandlerFunc { } func registerChat(g *gin.RouterGroup, conf config.ServerConfig) { - g.POST("", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatHandler()) // chat with bot + c, err := client.GetClient(nil) + if err != nil { + panic(err) + } + + chatService, err := NewChatService(c) + if err != nil { + panic(err) + } + + g.POST("", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ChatHandler()) // chat with bot - g.POST("/conversations", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), listConversationHandler()) // list conversations - g.DELETE("/conversations/:conversationID", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), deleteConversationHandler()) // delete conversation + g.POST("/conversations", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ListConversationHandler()) // list conversations + g.DELETE("/conversations/:conversationID", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.DeleteConversationHandler()) // delete conversation - g.POST("/messages", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), historyHandler()) // messages history - g.POST("/messages/:messageID/references", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), referenceHandler()) // messages reference + g.POST("/messages", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.HistoryHandler()) // messages history + g.POST("/messages/:messageID/references", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, "get", "applications"), requestid.RequestIDInterceptor(), chatService.ReferenceHandler()) // messages reference } diff --git a/deploy/charts/arcadia/templates/config.yaml b/deploy/charts/arcadia/templates/config.yaml index e2919746a..de03c8802 100644 --- a/deploy/charts/arcadia/templates/config.yaml +++ b/deploy/charts/arcadia/templates/config.yaml @@ -6,6 +6,13 @@ data: kind: Datasource name: '{{ .Release.Name }}-minio' namespace: '{{ .Release.Namespace }}' +{{- if .Values.postgresql.enabled }} + relationalDatasource: + apiGroup: arcadia.kubeagi.k8s.com.cn/v1alpha1 + kind: Datasource + name: {{ .Release.Name }}-postgresql + namespace: {{ .Release.Namespace }} +{{- end }} {{- if gt (len .Values.ray.clusters) 0 }} rayClusters: {{- range .Values.ray.clusters }} diff --git a/go.mod b/go.mod index fe28fbb6a..3303431cc 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/go-logr/logr v1.2.3 github.com/gofiber/fiber/v2 v2.50.0 github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/jackc/pgx/v5 v5.4.1 + github.com/jackc/pgx/v5 v5.4.3 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.3 github.com/r3labs/sse/v2 v2.10.0 @@ -26,6 +26,8 @@ require ( github.com/tmc/langchaingo v0.1.3 github.com/valyala/fasthttp v1.50.0 github.com/vektah/gqlparser/v2 v2.5.10 + gorm.io/driver/postgres v1.5.4 + gorm.io/gorm v1.25.5 k8s.io/api v0.24.2 k8s.io/apimachinery v0.24.2 k8s.io/client-go v0.24.2 @@ -57,7 +59,9 @@ require ( github.com/huandu/xstrings v1.3.3 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/puddle/v2 v2.2.0 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/minio/md5-simd v1.1.2 // indirect diff --git a/go.sum b/go.sum index 382b1d5db..98345b58c 100644 --- a/go.sum +++ b/go.sum @@ -432,13 +432,15 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.1 h1:oKfB/FhuVtit1bBM3zNRRsZ925ZkMN3HXL+LgLUM9lE= -github.com/jackc/pgx/v5 v5.4.1/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY= -github.com/jackc/puddle/v2 v2.2.0 h1:RdcDk92EJBuBS55nQMMYFXTxwstHug4jkhT5pq8VxPk= -github.com/jackc/puddle/v2 v2.2.0/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= @@ -1282,6 +1284,10 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/pkg/config/config.go b/pkg/config/config.go index 2b7c48572..8e31d216e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -48,18 +48,9 @@ var ( ErrNoConfigRayClusters = fmt.Errorf("config RayClusters in comfigmap is not found") ) -func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { - config, err := GetConfig(ctx, c, cli) - if err != nil { - return nil, err - } - name := config.SystemDatasource.Name - var namespace string - if config.SystemDatasource.Namespace != nil { - namespace = *config.SystemDatasource.Namespace - } else { - namespace = utils.GetCurrentNamespace() - } +func getDatasource(ctx context.Context, ref arcadiav1alpha1.TypedObjectReference, c client.Client, cli dynamic.Interface) (ds *arcadiav1alpha1.Datasource, err error) { + name := ref.Name + namespace := ref.GetNamespace(utils.GetCurrentNamespace()) source := &arcadiav1alpha1.Datasource{} if c != nil { if err = c.Get(ctx, client.ObjectKey{Name: name, Namespace: namespace}, source); err != nil { @@ -79,6 +70,22 @@ func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Inter return source, err } +func GetSystemDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { + config, err := GetConfig(ctx, c, cli) + if err != nil { + return nil, err + } + return getDatasource(ctx, config.SystemDatasource, c, cli) +} + +func GetRelationalDatasource(ctx context.Context, c client.Client, cli dynamic.Interface) (*arcadiav1alpha1.Datasource, error) { + config, err := GetConfig(ctx, c, cli) + if err != nil { + return nil, err + } + return getDatasource(ctx, config.RelationalDatasource, c, cli) +} + func GetGateway(ctx context.Context, c client.Client, cli dynamic.Interface) (*Gateway, error) { config, err := GetConfig(ctx, c, cli) if err != nil { diff --git a/pkg/config/config_type.go b/pkg/config/config_type.go index 71d6d85d8..fcdd0a342 100644 --- a/pkg/config/config_type.go +++ b/pkg/config/config_type.go @@ -25,6 +25,9 @@ type Config struct { // SystemDatasource specifies the built-in datasource for Arcadia to host data files and model files SystemDatasource arcadiav1alpha1.TypedObjectReference `json:"systemDatasource,omitempty"` + // RelationalDatasource specifies the built-in datasource(common:postgres) for Arcadia to host relational data + RelationalDatasource arcadiav1alpha1.TypedObjectReference `json:"relationalDatasource,omitempty"` + // Gateway to access LLM api services Gateway *Gateway `json:"gateway,omitempty"` diff --git a/pkg/datasource/postgresql.go b/pkg/datasource/postgresql.go index 6486ba0ff..f264f6ca7 100644 --- a/pkg/datasource/postgresql.go +++ b/pkg/datasource/postgresql.go @@ -31,7 +31,7 @@ import ( var ( _ Datasource = (*PostgreSQL)(nil) - locker sync.Mutex + pgEnvMutex sync.Mutex poolsMutex sync.Mutex pools = make(map[string]*PostgreSQL) ) @@ -88,8 +88,8 @@ func newPostgreSQL(ctx context.Context, c client.Client, dc dynamic.Interface, c pgPassFile = string(data[v1alpha1.PGPASSFILE]) pgSSLPassword = string(data[v1alpha1.PGSSLPASSWORD]) } - locker.Lock() - defer locker.Unlock() + pgEnvMutex.Lock() + defer pgEnvMutex.Unlock() if pgUser != "" { if err := os.Setenv("PGUSER", pgUser); err != nil { return nil, err