diff --git a/backend/app/alembic/versions/830fd9c44f39_.py b/backend/app/alembic/versions/830fd9c44f39_.py index 222afb45..b80cef52 100644 --- a/backend/app/alembic/versions/830fd9c44f39_.py +++ b/backend/app/alembic/versions/830fd9c44f39_.py @@ -32,11 +32,20 @@ def upgrade(): "origin", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=True ), ) + op.add_column( + "chat_messages", + sa.Column( + "post_verification_result_url", + sqlmodel.sql.sqltypes.AutoString(length=512), + nullable=True, + ), + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("chat_messages", "post_verification_result_url") op.drop_column("feedbacks", "origin") op.drop_column("chats", "origin") # ### end Alembic commands ### diff --git a/backend/app/models/chat_message.py b/backend/app/models/chat_message.py index cc8a667d..fe9a1fd1 100644 --- a/backend/app/models/chat_message.py +++ b/backend/app/models/chat_message.py @@ -38,5 +38,9 @@ class ChatMessage(UpdatableBaseModel, table=True): "primaryjoin": "ChatMessage.user_id == User.id", }, ) + post_verification_result_url: Optional[str] = Field( + max_length=512, + nullable=True, + ) __tablename__ = "chat_messages" diff --git a/backend/app/rag/chat.py b/backend/app/rag/chat.py index 3d16f088..56e1b2e7 100644 --- a/backend/app/rag/chat.py +++ b/backend/app/rag/chat.py @@ -3,6 +3,7 @@ from uuid import UUID from typing import List, Generator, Optional, Tuple from datetime import datetime, UTC +from urllib.parse import urljoin import requests import jinja2 @@ -421,18 +422,6 @@ def _get_llamaindex_callback_manager(): if not response_text: raise Exception("Got empty response from LLM") - db_assistant_message.sources = source_documents - db_assistant_message.graph_data = graph_data_source_ids - db_assistant_message.content = response_text - db_assistant_message.updated_at = datetime.now(UTC) - db_assistant_message.finished_at = datetime.now(UTC) - self.db_session.add(db_assistant_message) - db_user_message.graph_data = graph_data_source_ids - db_user_message.updated_at = datetime.now(UTC) - db_user_message.finished_at = datetime.now(UTC) - self.db_session.add(db_user_message) - self.db_session.commit() - yield ChatEvent( event_type=ChatEventType.MESSAGE_ANNOTATIONS_PART, payload=ChatStreamMessagePayload( @@ -440,13 +429,26 @@ def _get_llamaindex_callback_manager(): ), ) - self._post_verification( + post_verification_result_url = self._post_verification( self.user_question, response_text, self.db_chat_obj.id, db_assistant_message.id, ) + db_assistant_message.sources = source_documents + db_assistant_message.graph_data = graph_data_source_ids + db_assistant_message.content = response_text + db_assistant_message.post_verification_result_url = post_verification_result_url + db_assistant_message.updated_at = datetime.now(UTC) + db_assistant_message.finished_at = datetime.now(UTC) + self.db_session.add(db_assistant_message) + db_user_message.graph_data = graph_data_source_ids + db_user_message.updated_at = datetime.now(UTC) + db_user_message.finished_at = datetime.now(UTC) + self.db_session.add(db_user_message) + self.db_session.commit() + yield ChatEvent( event_type=ChatEventType.DATA_PART, payload=ChatStreamDataPayload( @@ -498,7 +500,8 @@ def _get_source_documents(self, response: StreamingResponse) -> List[dict]: def _post_verification( self, user_question: str, response_text: str, chat_id: UUID, message_id: int - ): + ) -> Optional[str]: + # post verification to external service, will return the post verification result url post_verification_url = self.chat_engine_config.post_verification_url post_verification_token = self.chat_engine_config.post_verification_token @@ -522,6 +525,8 @@ def _post_verification( timeout=10, ) resp.raise_for_status() + job_id = resp.json()["job_id"] + return urljoin(f"{post_verification_url}/", str(job_id)) except Exception: logger.exception("Failed to post verification")