Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: regenerate in Chat, agent and Chatflow app #7661

Merged
merged 34 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f458bae
feat(wip): regenerate icon
xuzuodong Aug 22, 2024
3fb46a4
feat: regenerate chat message for and app
xuzuodong Aug 26, 2024
7719abc
chore: lint python code
xuzuodong Aug 26, 2024
2d4e9c9
Merge branch 'main' into feat/chat-regenerate
xuzuodong Aug 26, 2024
f15004c
chore: adapt new Tooltip component
xuzuodong Aug 26, 2024
9c394b2
chore: finish basic logic
xuzuodong Aug 27, 2024
45d540c
lint python code
xuzuodong Aug 27, 2024
b03403c
chore: remove debug code
xuzuodong Aug 27, 2024
80fd810
chore: add line break
xuzuodong Aug 27, 2024
19e01d3
chore: allow getting message list by desc to avoid breaking change to…
xuzuodong Aug 27, 2024
37ea507
chore: add missing field to related endpoint
xuzuodong Aug 27, 2024
8b59da0
chore: hide regenerate button in log view
xuzuodong Aug 29, 2024
a304d89
Merge branch 'main' into feat/chat-regenerate
xuzuodong Aug 29, 2024
62f5236
chore: update db migrate script
xuzuodong Aug 29, 2024
6ec8755
lint python code and remove unnecessary code
xuzuodong Aug 29, 2024
fcd178e
fix: generation api error
xuzuodong Aug 29, 2024
3475486
fix: regenerate not working in chatflow and embed app
xuzuodong Aug 30, 2024
e7c58c1
chore: minor change
xuzuodong Aug 30, 2024
53e0865
chore: add field to all non-first messages
xuzuodong Sep 1, 2024
863e595
chor: update comments
xuzuodong Sep 2, 2024
193b5d9
Merge remote-tracking branch 'origin/main' into feat/chat-regenerate
xuzuodong Sep 4, 2024
3a5c0b3
chore: update db mirgration code
xuzuodong Sep 4, 2024
823c592
chore: lint python code
xuzuodong Sep 4, 2024
f89ca21
refactor: use only
xuzuodong Sep 7, 2024
c32f273
style: regenerate icon
xuzuodong Sep 7, 2024
5b6a03c
chore: lint python code
xuzuodong Sep 7, 2024
1a407f1
Merge remote-tracking branch 'origin/main' into feat/chat-regenerate
xuzuodong Sep 11, 2024
c585fae
chore: lint python code
xuzuodong Sep 11, 2024
c7aa2b3
chore: db migrate version
xuzuodong Sep 11, 2024
1d68b58
chore: lint python code
xuzuodong Sep 11, 2024
5606c17
chore: change icon of regenerate button
xuzuodong Sep 14, 2024
c4a13c7
chore: only show latest thread messages in log list and chatflow run …
xuzuodong Sep 14, 2024
421569a
Merge remote-tracking branch 'origin/main' into feat/chat-regenerate
xuzuodong Sep 14, 2024
e07b550
chore: lint python code
xuzuodong Sep 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
1 change: 1 addition & 0 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def post(self, app_model):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
Expand Down
2 changes: 0 additions & 2 deletions api/controllers/console/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def get(self, app_model):
if rest_count > 0:
has_more = True

history_messages = list(reversed(history_messages))

return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)


Expand Down
2 changes: 2 additions & 0 deletions api/controllers/console/app/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def post(self, app_model: App):
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")

args = parser.parse_args()

try:
Expand Down
1 change: 1 addition & 0 deletions api/controllers/console/explore/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def post(self, installed_app):
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/explore/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get(self, installed_app):

try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
Expand Down
1 change: 1 addition & 0 deletions api/controllers/service_api/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MessageListApi(Resource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down
1 change: 1 addition & 0 deletions api/controllers/web/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def post(self, app_model, end_user):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")

args = parser.parse_args()
Expand Down
3 changes: 2 additions & 1 deletion api/controllers/web/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down Expand Up @@ -89,7 +90,7 @@ def get(self, app_model, end_user):

try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
Expand Down
5 changes: 4 additions & 1 deletion api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
Expand Down Expand Up @@ -441,10 +442,12 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.order_by(Message.created_at.desc())
.all()
)

messages = list(reversed(extract_thread_messages(messages)))

for message in messages:
if message.id == self.message.id:
continue
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/message_based_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _init_generate_records(
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
Expand Down
3 changes: 3 additions & 0 deletions api/core/app/entities/app_invoke_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None


class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
Expand All @@ -138,6 +139,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None


class AdvancedChatAppGenerateEntity(AppGenerateEntity):
Expand All @@ -149,6 +151,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
app_config: WorkflowUIBasedAppConfig

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
query: str

class SingleIterationRunEntity(BaseModel):
Expand Down
21 changes: 18 additions & 3 deletions api/core/memory/token_buffer_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
Expand All @@ -33,8 +34,17 @@ def get_history_prompt_messages(

# fetch limited messages, and return reversed
query = (
db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id)
.filter(Message.conversation_id == self.conversation.id, Message.answer != "")
db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id,
Message.parent_message_id,
)
.filter(
Message.conversation_id == self.conversation.id,
xuzuodong marked this conversation as resolved.
Show resolved Hide resolved
)
.order_by(Message.created_at.desc())
)

Expand All @@ -45,7 +55,12 @@ def get_history_prompt_messages(

messages = query.limit(message_limit).all()

messages = list(reversed(messages))
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
thread_messages = extract_thread_messages(messages)
thread_messages.pop(0)
messages = list(reversed(thread_messages))

message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = []
for message in messages:
Expand Down
22 changes: 22 additions & 0 deletions api/core/prompt/utils/extract_thread_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from constants import UUID_NIL


def extract_thread_messages(messages: list[dict]) -> list[dict]:
thread_messages = []
next_message = None

for message in messages:
if not message.parent_message_id:
# If the message is regenerated and does not have a parent message, it is the start of a new thread
thread_messages.append(message)
break

if not next_message:
thread_messages.append(message)
next_message = message.parent_message_id
else:
if next_message in {message.id, UUID_NIL}:
thread_messages.append(message)
next_message = message.parent_message_id

return thread_messages
1 change: 1 addition & 0 deletions api/fields/conversation_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def format(self, value):
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
}

feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
Expand Down
1 change: 1 addition & 0 deletions api/fields/message_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""add parent_message_id to messages

Revision ID: d57ba9ebb251
Revises: 675b5321501b
Create Date: 2024-09-11 10:12:45.826265

"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = 'd57ba9ebb251'
down_revision = '675b5321501b'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))

# Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs
op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_column('parent_message_id')

# ### end Alembic commands ###
1 change: 1 addition & 0 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ class Message(db.Model):
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
parent_message_id = db.Column(StringUUID, nullable=True)
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255), nullable=False)
Expand Down
4 changes: 3 additions & 1 deletion api/services/message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def pagination_by_first_id(
conversation_id: str,
first_id: Optional[str],
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
Expand Down Expand Up @@ -91,7 +92,8 @@ def pagination_by_first_id(
if rest_count > 0:
has_more = True

history_messages = list(reversed(history_messages))
if order == "asc":
history_messages = list(reversed(history_messages))

return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)

Expand Down
91 changes: 91 additions & 0 deletions api/tests/unit_tests/core/prompt/test_extract_thread_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from uuid import uuid4

from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages


class TestMessage:
def __init__(self, id, parent_message_id):
self.id = id
self.parent_message_id = parent_message_id

def __getitem__(self, item):
return getattr(self, item)


def test_extract_thread_messages_single_message():
messages = [TestMessage(str(uuid4()), UUID_NIL)]
result = extract_thread_messages(messages)
assert len(result) == 1
assert result[0] == messages[0]


def test_extract_thread_messages_linear_thread():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id3),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 5
assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]


def test_extract_thread_messages_branched_thread():
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id4, id2, id1]


def test_extract_thread_messages_empty_list():
messages = []
result = extract_thread_messages(messages)
assert len(result) == 0


def test_extract_thread_messages_partially_loaded():
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, id0),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]


def test_extract_thread_messages_legacy_messages():
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, UUID_NIL),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]


def test_extract_thread_messages_mixed_with_legacy_messages():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 4
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
Loading
Loading