-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
95 lines (79 loc) · 3.12 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from transformers import pipeline
from langchain.llms import CTransformers #to get llm
from langchain.prompts import ChatPromptTemplate
import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter#splitting text into chunks
from utils.VectorDBStorer import VectorDBStorer
from utils.UploadData import UploadData
import os
from flask import Flask, render_template, request, redirect, session
from werkzeug.utils import secure_filename
import secrets
# Create the Flask application instance.
# (Previous comment said "FastAPI" — this is a Flask app; see the imports.)
app = Flask(__name__)
# NOTE(review): a fresh random secret per process invalidates all existing
# sessions on every restart; fine for a demo, but production should load a
# fixed key from the environment.
app.secret_key = secrets.token_hex(16)
# Global variable to store the collection object
collection = None  # Chroma collection: written by /upload, read by /generate
BUCKET = "hfdataset"  # S3 bucket the uploaded files are pushed to
UPLOAD_FOLDER = "uploads"  # local staging directory for incoming uploads
@app.get("/generate")
def generate():
    """
    Answer the `query` request parameter against the indexed documents.

    Retrieves the single most similar chunk from the global Chroma
    ``collection`` (populated by ``/upload``) and feeds it, together with the
    question, to a local Llama-2-7B-Chat GGML model loaded via CTransformers.
    Renders the model output into ``query_page.html``.

    NOTE(review): the previous docstring claimed a `transformers`
    text2text-generation pipeline with `google/flan-t5-small` was used —
    the code actually runs Llama-2 through CTransformers/LangChain.
    """
    global collection
    # /upload must run first to build the collection; the old code crashed
    # with AttributeError (`None.query`) when it hadn't.
    if collection is None:
        return render_template(
            "query_page.html",
            message="No documents indexed yet - upload a file first.")
    text = request.args.get('query')
    if not text:
        return render_template(
            "query_page.html",
            message="Missing 'query' parameter.")
    print(text)
    # NOTE(review): loading the model on every request is slow; acceptable
    # for a demo, but a production app should load it once at startup.
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        temperature=0.2,
    )
    custom_prompt_template = """Use the following pieces of information to answer the user’s question.
Context: {context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(custom_prompt_template)
    chain = prompt | llm
    # Retrieve the most relevant chunk as context for the prompt.
    results = collection.query(
        query_texts=text,
        n_results=1)
    context = results['documents'][0][0]
    output = chain.invoke({"context": context, "question": text})
    print(output)
    return render_template("query_page.html", response=output)
@app.route("/")
def home():
    """Render the landing page (``index.html``)."""
    page = render_template('index.html')
    return page
@app.route("/upload", methods=['POST'])
def upload():
    """
    Accept a file upload, push it to S3, and (re)build the vector collection.

    Saves the posted file to the local staging folder, uploads it to the
    configured S3 bucket via ``UploadData``, removes the local copy, and on
    success rebuilds the global Chroma ``collection`` through
    ``VectorDBStorer``. Renders a success/error message, or redirects home
    when no file was provided.
    """
    global collection
    # Route is POST-only, so the old `request.method == "POST"` check was
    # redundant. `.get` avoids a 400 when the 'file' field is absent.
    f = request.files.get('file')
    if f and f.filename:
        # BUG FIX: the file was saved under the *sanitized* name but the
        # S3-upload/removal path used the raw `f.filename`, so any filename
        # that secure_filename() changes raised FileNotFoundError. Use the
        # sanitized name consistently.
        filename = secure_filename(f.filename)
        os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # staging dir may not exist yet
        path = os.path.join(os.getcwd(), UPLOAD_FOLDER, filename)
        f.save(path)
        data = UploadData()
        # Upload the file to S3, then drop the local staging copy.
        success, upload_response = data.upload(path, BUCKET)
        os.remove(path)
        # Display success or error message
        if success:
            message = "File successfully uploaded to S3!"
            client = chromadb.Client()
            collection_name = "new_scientific_papers"
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=20)
            vector_db_storer = VectorDBStorer(client, collection_name, text_splitter, BUCKET)
            collection = vector_db_storer.get_collection()
        else:
            message = f"Error uploading file to S3: {upload_response}"
        return render_template("query_page.html", message=message)
    return redirect("/")
if __name__ == '__main__':
    # Development server bound to all interfaces with the debugger enabled;
    # do not run this configuration in production.
    app.run(debug=True, host="0.0.0.0", port=5000)