forked from RasaHQ/rasa-calm-demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load-data-to-qdrant.py
48 lines (42 loc) · 1.44 KB
/
load-data-to-qdrant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from datasets import load_dataset
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores.qdrant import Qdrant
import logging
# Configure the root logger at INFO so the "✅" progress messages below are shown.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
def dataset_to_documents(dataset):
    """Convert SQuAD-style question/answer rows into LangChain ``Document``s.

    Each row becomes one FAQ document. Its ``page_content`` is the question,
    which is the text that gets embedded and searched. Its metadata holds the
    first reference answer plus the row's ``id`` and ``title``.

    Args:
        dataset: An iterable of mappings with ``question``, ``answers`` (a
            mapping whose ``text`` entry is a list of strings), ``id`` and
            ``title`` keys, e.g. one split of a Hugging Face dataset.

    Returns:
        A list of ``Document`` objects in dataset order. Rows whose
        ``answers['text']`` list is empty (unanswerable questions, as in
        SQuAD v2) are skipped: an FAQ entry with no answer cannot be served.
    """
    documents = []
    # Iterate rows directly and read each row once. Indexing a Hugging Face
    # Dataset with dataset[i] materializes a fresh row dict on every access.
    for row in dataset:
        answer_texts = row['answers']['text']
        if not answer_texts:
            continue
        documents.append(
            Document(
                page_content=row['question'],
                metadata={
                    'type': 'faq',
                    'answer': answer_texts[0],
                    'id': row['id'],
                    'title': row['title'],
                },
            )
        )
    return documents
def load_dataset_to_qdrant(
    dataset_name,
    *,
    split="train",
    model_name="BAAI/bge-small-en-v1.5",
    device="cpu",
    host="localhost",
    collection_name="squad",
):
    """Embed a Hugging Face QA dataset and index it into a Qdrant collection.

    Every keyword argument defaults to the value this script previously
    hard-coded, so ``load_dataset_to_qdrant("rajpurkar/squad")`` behaves as
    before.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``
            (e.g. ``"rajpurkar/squad"``).
        split: Which split of the loaded dataset to index.
        model_name: Model name passed to ``HuggingFaceEmbeddings``.
        device: Device string passed to the embedding model via
            ``model_kwargs``.
        host: Qdrant server host; the client is asked to prefer gRPC.
        collection_name: Qdrant collection the documents are written to.

    Returns:
        The vector store returned by ``Qdrant.from_documents``.
    """
    dataset = load_dataset(dataset_name)
    logger.info("✅ Dataset")
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        # L2-normalize vectors so dot-product and cosine scores coincide.
        encode_kwargs={'normalize_embeddings': True},
    )
    logger.info("✅ Embeddings")
    docs = dataset_to_documents(dataset[split])
    return Qdrant.from_documents(
        docs,
        embeddings,
        host=host,
        prefer_grpc=True,
        collection_name=collection_name,
    )
if __name__ == "__main__":
    # Index SQuAD, then run one sanity-check query against the new collection.
    store = load_dataset_to_qdrant("rajpurkar/squad")
    logger.info("✅ Qdrant")
    query = "Who built Notre Dame?"
    hits = store.similarity_search_with_score(query)
    logger.info("Qdrant search result:")
    logger.info(hits)