-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
134 lines (111 loc) · 4.23 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
import os
import chainlit as cl
# Fail fast when the OpenAI key is not configured. The key does not need to be
# copied anywhere: the LangChain OpenAI wrappers used below (OpenAIEmbeddings,
# ChatOpenAI) pick up OPENAI_API_KEY from the environment on their own, and
# the `openai` module itself is not imported by this file.
if "OPENAI_API_KEY" not in os.environ:
    raise KeyError("OPENAI_API_KEY")

# Splits an uploaded document into ~1000-character chunks that overlap by 100
# characters, so text cut at a chunk boundary still appears whole in one chunk.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# System prompt for the QA chain. `{summaries}` is the placeholder that
# receives the retrieved document chunks.
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.
Example of your response should be:
```
The answer is foo
SOURCES: xyz
```
Begin!
----------------
{summaries}"""

# Chat prompt: the system instructions above followed by the user's question.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)

# Keyword arguments intended for the underlying "stuff" documents chain.
chain_type_kwargs = {"prompt": prompt}
@cl.on_chat_start
async def init():
    """Ask for text file(s), index them in Chroma and store a QA chain in the session.

    Runs at the start of each chat. The uploaded files are decoded as UTF-8,
    split into chunks with the module-level ``text_splitter``, embedded into a
    Chroma vector store, and wrapped in a ``RetrievalQAWithSourcesChain`` that
    uses the module-level custom prompt (``chain_type_kwargs``).

    Values stored in the user session:
        "metadatas": list of ``{"source": "<i>-pl"}`` dicts, one per chunk.
        "texts": list of chunk strings, index-aligned with "metadatas".
        "chain": the ready-to-use QA chain.
    """
    files = None  # list (name, path, size, type, content)

    # Keep asking until the prompt actually returns uploaded files.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text file to begin!",
            max_files=2,
            accept=["text/plain"],
        ).send()
    print(files)

    # The prompt accepts up to two files, so every uploaded file is indexed
    # (not just the first one). With a single file the message is unchanged.
    file_names = ", ".join(f"`{file.name}`" for file in files)
    msg = cl.Message(content=f"Processing {file_names}...")
    await msg.send()

    # Decode each file and split its text into chunks.
    texts = []
    for file in files:
        text = file.content.decode("utf-8")
        print(text)
        texts.extend(text_splitter.split_text(text))

    # One metadata entry per chunk; the "source" id is what the model cites
    # and what the message handler later maps back to the chunk text.
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]
    print(metadatas)

    # Build the Chroma vector store through cl.make_async so the synchronous
    # call can be awaited from this coroutine.
    embeddings = OpenAIEmbeddings()
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    # Create a chain that uses the Chroma vector store. The custom prompt is
    # passed explicitly; without chain_type_kwargs it would never be applied.
    chain = RetrievalQAWithSourcesChain.from_chain_type(
        ChatOpenAI(temperature=0, streaming=True),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        chain_type_kwargs=chain_type_kwargs,
    )

    # Save the metadata and texts in the user session
    cl.user_session.set("metadatas", metadatas)
    cl.user_session.set("texts", texts)

    # Let the user know that the system is ready
    msg.content = f"Processing {file_names} done. \nYou can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    """Answer a user message with the session's QA chain and attach cited chunks.

    Reads "chain", "metadatas" and "texts" from the user session, runs the
    chain on ``message``, and maps each source id in the chain's "sources"
    output back to its chunk text. Matching chunks are attached as text
    elements; source ids that are not known are ignored.
    """
    qa_chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain
    handler = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    handler.answer_reached = True

    # The result carries the question, the answer and the cited sources.
    result = await qa_chain.acall(message, callbacks=[handler])
    print(result)

    answer = result["answer"]
    cited = result["sources"].strip()
    elements = []

    # Chunk bookkeeping saved in the user session at chat start.
    metadatas = cl.user_session.get("metadatas")
    known_sources = [entry["source"] for entry in metadatas]
    texts = cl.user_session.get("texts")

    if cited:
        matched_names = []
        for raw_name in cited.split(","):
            name = raw_name.strip().replace(".", "")
            # Skip anything that is not one of the indexed source ids.
            if name not in known_sources:
                continue
            chunk = texts[known_sources.index(name)]
            matched_names.append(name)
            # Text element referenced from the answer message.
            elements.append(cl.Text(content=chunk, name=name))

        if matched_names:
            answer += f"\nSources: {', '.join(matched_names)}"
        else:
            answer += "\nNo sources found"

    # No streamed message exists yet: send the complete answer as a new one.
    if not handler.has_streamed_final_answer:
        await cl.Message(content=answer, elements=elements).send()
        return

    # The answer was already streamed: attach the elements to that message.
    handler.final_stream.elements = elements
    await handler.final_stream.update()