-
Notifications
You must be signed in to change notification settings - Fork 0
/
helper.py
166 lines (133 loc) · 4.79 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from llama_parse import LlamaParse
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from pathlib import Path
import streamlit as st
import joblib
import time
# from dotenv import load_dotenv
import os
# load_dotenv()
# Llama Cloud API key for LlamaParse, read from the environment.
# May be None if the variable is not set — parsing will then fail at call time.
LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")

# Prompt used by the retrieval QA chain.  Template variables:
#   {context} — the retrieved document chunks stuffed in by the chain
#   {input}   — the user's question
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {input}
Make sure to explain the answer in great detail, covering as much information as possible
"""
def iter_files_pdf(directory: str):
    """Yield the names of the regular ``.pdf`` files directly inside *directory*."""
    for entry_name in os.listdir(directory):
        full_path = os.path.join(directory, entry_name)
        # Skip anything that is not a plain file ending in '.pdf'.
        if entry_name.endswith('.pdf') and os.path.isfile(full_path):
            yield entry_name
def iter_files_pkl(directory: str):
    """Yield the names of the regular ``.pkl`` files directly inside *directory*."""
    for entry_name in os.listdir(directory):
        full_path = os.path.join(directory, entry_name)
        # Skip anything that is not a plain file ending in '.pkl'.
        if entry_name.endswith(".pkl") and os.path.isfile(full_path):
            yield entry_name
def parsed_doc(directory: str, parsing_instruction: str, LLAMA_CLOUD_API_KEY: str):
    """Parse the file at *directory* with LlamaParse and return the parsed documents.

    Args:
        directory: path of the file to parse (despite the name, a single file path).
        parsing_instruction: natural-language instruction passed to LlamaParse.
        LLAMA_CLOUD_API_KEY: Llama Cloud API key used to authenticate.
    """
    markdown_parser = LlamaParse(
        result_type="markdown",
        api_key=LLAMA_CLOUD_API_KEY,
        parsing_instruction=parsing_instruction,
    )
    return markdown_parser.load_data(directory)
def set_custom_prompt(custom_prompt_template: str):
    """
    Build the QA retrieval PromptTemplate for the vectorstore chain.

    Args:
        custom_prompt_template: template string exposing the variables
            ``{context}`` (retrieved document chunks) and ``{input}``
            (the user's question) — the names that
            ``create_retrieval_chain``/``create_stuff_documents_chain`` supply.

    Returns:
        A configured ``PromptTemplate``.
    """
    # Fix: the module's template uses {context} and {input}; declaring
    # 'question' (which is absent from the template) mismatched the declared
    # and actual template variables.
    prompt = PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'input'],
    )
    return prompt
def parse_pdfs(uploaded_files, parsing_instruction: str) -> None:
    """
    Save uploaded files to the working directory, parse every PDF there with
    LlamaParse, pickle each parse result, and append all parsed text to
    ``TEMP_output.md``.

    Args:
        uploaded_files: list of uploaded files (Streamlit UploadedFile-like
            objects exposing ``.name`` and ``.getbuffer()``)
        parsing_instruction: instruction for parsing the PDFs
    """
    # Stage 1: persist each uploaded file to CWD under a 'TEMP_' prefix.
    for uploaded_file in uploaded_files:
        # Force a '.pdf' suffix so the parse loop below picks the file up.
        if not uploaded_file.name.endswith('.pdf'):
            filename = 'TEMP_' + uploaded_file.name + '.pdf'
        else:
            filename = 'TEMP_' + uploaded_file.name
        # NOTE(review): appends unconditionally — a rerun with the same file
        # adds a duplicate entry to session_state.files_list; confirm intended.
        st.session_state.files_list.append(filename)
        save_path = os.path.join(os.getcwd(), filename)
        # if file is already saved, don't save it again
        if os.path.exists(save_path):
            print(f"\"{uploaded_file.name}\" already exists!")
            continue
        else:
            print(f"\"{uploaded_file.name}\" does not exist! saving...")
            with open(save_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
    # Stage 2: parse every PDF in CWD (not just this batch) and pickle results.
    with st.spinner(text="Parsing the PDFs..."):
        directory = os.getcwd()
        for filename in iter_files_pdf(directory=directory):
            document = parsed_doc(
                directory=filename,
                parsing_instruction=parsing_instruction,
                LLAMA_CLOUD_API_KEY=LLAMA_CLOUD_API_KEY
            )
            # '<name>.pdf' -> '<name>.pkl'
            joblib.dump(document, filename[:-4] + ".pkl")
            # presumably throttles LlamaParse API calls — TODO confirm
            time.sleep(5)
        # Ensure the output file exists; touch() does NOT clear prior content.
        Path('TEMP_output.md').touch()
        # add the saved pickles to output.md
        # NOTE(review): 'a' mode means repeated runs append the same text
        # again — confirm whether TEMP_output.md should be truncated first.
        for filename in iter_files_pkl(directory=directory):
            docs = joblib.load(filename)
            with open('TEMP_output.md', 'a', encoding='utf-8') as f:
                for doc in docs:
                    f.write(doc.text + '\n')
def load_output():
    """
    Load ``TEMP_output.md``, split it into overlapping chunks, embed them,
    and return a Chroma vectorstore built from the result.
    """
    with st.spinner(text="Loading output.md..."):
        markdown_docs = UnstructuredMarkdownLoader('TEMP_output.md').load()
        # Chunk the markdown so each embedded piece stays a manageable size.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=100,
        )
        chunks = splitter.split_documents(markdown_docs)
        embedder = FastEmbedEmbeddings()
        vectorstore = Chroma.from_documents(
            documents=chunks,
            collection_name='pdf_collection',
            embedding=embedder,
        )
        return vectorstore
def create_chat_chain(vectorstore, GROQ_API_KEY: str):
    """
    Wire a Groq chat model, a top-k retriever over *vectorstore*, and the
    module's custom prompt into a retrieval QA chain, and return that chain.
    """
    with st.spinner(text="Creating chat model..."):
        # create chain
        llm = ChatGroq(
            model_name='llama3-8b-8192',
            api_key=GROQ_API_KEY
        )
        # Retrieve the 3 most similar chunks for each query.
        top_k_retriever = vectorstore.as_retriever(search_kwargs={'k': 3})
        qa_prompt = set_custom_prompt(custom_prompt_template=custom_prompt_template)
        # create chat chain
        stuff_chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
        qa = create_retrieval_chain(
            retriever=top_k_retriever,
            combine_docs_chain=stuff_chain
        )
        return qa