diff --git a/cypress/e2e/data_layer/main.py b/cypress/e2e/data_layer/main.py index 21c682e836..e1752215fc 100644 --- a/cypress/e2e/data_layer/main.py +++ b/cypress/e2e/data_layer/main.py @@ -1,6 +1,9 @@ +import os.path +import pickle from typing import Dict, List, Optional import chainlit.data as cl_data +from chainlit.socket import persist_user_session from chainlit.step import StepDict from literalai.helper import utc_now @@ -8,9 +11,6 @@ now = utc_now() -create_step_counter = 0 - - thread_history = [ { "id": "test1", @@ -61,6 +61,22 @@ ] # type: List[cl_data.ThreadDict] deleted_thread_ids = [] # type: List[str] +THREAD_HISTORY_PICKLE_PATH = os.getenv("THREAD_HISTORY_PICKLE_PATH") +if THREAD_HISTORY_PICKLE_PATH and os.path.exists(THREAD_HISTORY_PICKLE_PATH): + with open(THREAD_HISTORY_PICKLE_PATH, "rb") as f: + thread_history = pickle.load(f) + + +async def save_thread_history(): + if THREAD_HISTORY_PICKLE_PATH: + # Force saving of thread history for reload when server restarts + await persist_user_session( + cl.context.session.thread_id, cl.context.session.to_persistable() + ) + + with open(THREAD_HISTORY_PICKLE_PATH, "wb") as out_file: + pickle.dump(thread_history, out_file) + class TestDataLayer(cl_data.BaseDataLayer): async def get_user(self, identifier: str): @@ -101,8 +117,9 @@ async def update_thread( @cl_data.queue_until_user_message() async def create_step(self, step_dict: StepDict): - global create_step_counter - create_step_counter += 1 + cl.user_session.set( + "create_step_counter", cl.user_session.get("create_step_counter") + 1 + ) thread = next( (t for t in thread_history if t["id"] == step_dict.get("threadId")), None @@ -138,11 +155,14 @@ async def delete_thread(self, thread_id: str): async def send_count(): + create_step_counter = cl.user_session.get("create_step_counter") await cl.Message(f"Create step counter: {create_step_counter}").send() @cl.on_chat_start async def main(): + # Add step counter to session so that it is saved in thread metadata + cl.user_session.set("create_step_counter", 0) await cl.Message("Hello, send me a message!").send() await send_count() @@ -157,6 +177,8 @@ async def handle_message(): await cl.Message("Ok!").send() await send_count() + await save_thread_history() + @cl.password_auth_callback def auth_callback(username: str, password: str) -> Optional[cl.User]: diff --git a/cypress/e2e/data_layer/spec.cy.ts b/cypress/e2e/data_layer/spec.cy.ts index 4f16e8dc8b..0320e89e98 100644 --- a/cypress/e2e/data_layer/spec.cy.ts +++ b/cypress/e2e/data_layer/spec.cy.ts @@ -1,4 +1,7 @@ +import { sep } from 'path'; + import { runTestServer, submitMessage } from '../../support/testUtils'; +import { ExecutionMode } from '../../support/utils'; function login() { cy.get("[id='email']").type('admin'); @@ -71,9 +74,50 @@ function resumeThread() { cy.get('.step').eq(8).should('contain', 'chat_profile'); } +function restartServer( + mode: ExecutionMode = undefined, + env?: Record +) { + const pathItems = Cypress.spec.absolute.split(sep); + const testName = pathItems[pathItems.length - 2]; + cy.exec(`pnpm exec ts-node ./cypress/support/run.ts ${testName} ${mode}`, { + env + }); +} + +function continueThread() { + cy.get('.step').eq(7).should('contain', 'Welcome back to Hello'); + + submitMessage('Hello after restart'); + + // Verify that new step counter messages have been added + cy.get('.step').eq(11).should('contain', 'Create step counter: 14'); + cy.get('.step').eq(14).should('contain', 'Create step counter: 17'); +} + +function newThread() { + cy.get('#new-chat-button').click(); + cy.get('#confirm').click(); +} + describe('Data Layer', () => { - before(() => { - runTestServer(); + beforeEach(() => { + // Set up the thread history file + const pathItems = Cypress.spec.absolute.split(sep); + pathItems[pathItems.length - 1] = 'thread_history.pickle'; + const threadHistoryFile = pathItems.join(sep); + cy.wrap(threadHistoryFile).as('threadHistoryFile'); + + runTestServer(undefined, { + THREAD_HISTORY_PICKLE_PATH: threadHistoryFile + }); + }); + + afterEach(() => { + cy.get('@threadHistoryFile').then((threadHistoryFile) => { + // Clean up the thread history file + cy.exec(`rm ${threadHistoryFile}`); + }); }); describe('Data Features with persistence', () => { @@ -84,5 +128,24 @@ describe('Data Layer', () => { threadList(); resumeThread(); }); + + it('should continue the thread after backend restarts and work with new thread as usual', () => { + login(); + feedback(); + threadQueue(); + + cy.get('@threadHistoryFile').then((threadHistoryFile) => { + restartServer(undefined, { + THREAD_HISTORY_PICKLE_PATH: `${threadHistoryFile}` + }); + }); + // Continue the thread and verify that the step counter is not reset + continueThread(); + + // Create a new thread and verify that the step counter is reset + newThread(); + feedback(); + threadQueue(); + }); }); }); diff --git a/libs/react-client/src/useChatSession.ts b/libs/react-client/src/useChatSession.ts index 6ff4e77bb8..ffa817c586 100644 --- a/libs/react-client/src/useChatSession.ts +++ b/libs/react-client/src/useChatSession.ts @@ -1,5 +1,5 @@ import { debounce } from 'lodash'; -import { useCallback, useContext } from 'react'; +import { useCallback, useContext, useEffect } from 'react'; import { useRecoilState, useRecoilValue, @@ -63,7 +63,17 @@ const useChatSession = () => { const setTokenCount = useSetRecoilState(tokenCountState); const [chatProfile, setChatProfile] = useRecoilState(chatProfileState); const idToResume = useRecoilValue(threadIdToResumeState); - const setCurrentThreadId = useSetRecoilState(currentThreadIdState); + const [currentThreadId, setCurrentThreadId] = + useRecoilState(currentThreadIdState); + + // Use currentThreadId as thread id in websocket header + useEffect(() => { + if (session?.socket) { + session.socket.io.opts.extraHeaders!['X-Chainlit-Thread-Id'] = + currentThreadId || ''; + } + }, [currentThreadId]); + const _connect = useCallback( ({ userEnv,