Commit 1b569c9

add ut

fanzhidongyzby committed Aug 30, 2024
1 parent f17c32b
Showing 4 changed files with 157 additions and 204 deletions.
95 changes: 62 additions & 33 deletions examples/rag/graph_rag_example.py
@@ -3,12 +3,17 @@
 import pytest
 
 from dbgpt.configs.model_config import ROOT_PATH
-from dbgpt.core import ModelMessage, HumanPromptTemplate, ModelRequest, Chunk
+from dbgpt.core import Chunk, HumanPromptTemplate, ModelMessage, ModelRequest
 from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
 from dbgpt.rag import ChunkParameters
 from dbgpt.rag.assembler import EmbeddingAssembler
+from dbgpt.rag.embedding import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge import KnowledgeFactory
 from dbgpt.rag.retriever import RetrieverStrategy
+from dbgpt.storage.knowledge_graph.community_summary import (
+    CommunitySummaryKnowledgeGraph,
+    CommunitySummaryKnowledgeGraphConfig,
+)
 from dbgpt.storage.knowledge_graph.knowledge_graph import (
     BuiltinKnowledgeGraph,
     BuiltinKnowledgeGraphConfig,
@@ -18,40 +23,75 @@
 pre-requirements:
 * Set LLM config (url/sk) in `.env`.
 * Install pytest utils: `pip install pytest pytest-asyncio`
+* Config TuGraph following the format below.
+```
+GRAPH_STORE_TYPE=TuGraph
+TUGRAPH_HOST=127.0.0.1
+TUGRAPH_PORT=7687
+TUGRAPH_USERNAME=admin
+TUGRAPH_PASSWORD=73@TuGraph
+```
 Examples:
 .. code-block:: shell
     pytest -s examples/rag/graph_rag_example.py
 """
 
 llm_client = OpenAILLMClient()
 model_name = "gpt-4o-mini"
-rag_template = (
-    "Based on the following [Context] {context}, answer [Question] {question}."
-)
 
-file = "examples/test_files/graphrag-mini.md"
-question = "What is TuGraph ?"
+
+@pytest.mark.asyncio
+async def test_naive_graph_rag():
+    await __run_graph_rag(
+        knowledge_file="examples/test_files/graphrag-mini.md",
+        chunk_strategy="CHUNK_BY_SIZE",
+        knowledge_graph=__create_naive_kg_connector(),
+        question="What's the relationship between TuGraph and DB-GPT ?",
+    )
+
+
+@pytest.mark.asyncio
+async def test_community_graph_rag():
+    await __run_graph_rag(
+        knowledge_file="examples/test_files/graphrag-mini.md",
+        chunk_strategy="CHUNK_BY_MARKDOWN_HEADER",
+        knowledge_graph=__create_community_kg_connector(),
+        question="What's the relationship between TuGraph and DB-GPT ?",
+    )
 
 
-def _create_kg_connector():
+def __create_naive_kg_connector():
     """Create knowledge graph connector."""
     return BuiltinKnowledgeGraph(
         config=BuiltinKnowledgeGraphConfig(
-            name="graph_rag_test",
+            name="naive_graph_rag_test",
             embedding_fn=None,
             llm_client=llm_client,
             model_name=model_name,
-            graph_store_type='MemoryGraph'
+            graph_store_type="MemoryGraph",
         ),
     )
 
 
-async def chat_rag(chunk: Chunk) -> str:
-    template = HumanPromptTemplate.from_template(rag_template)
-    messages = template.format_messages(
-        context=chunk,
-        question=question
+def __create_community_kg_connector():
+    """Create community knowledge graph connector."""
+    return CommunitySummaryKnowledgeGraph(
+        config=CommunitySummaryKnowledgeGraphConfig(
+            name="community_graph_rag_test",
+            embedding_fn=DefaultEmbeddingFactory.openai(),
+            llm_client=llm_client,
+            model_name=model_name,
+            graph_store_type="TuGraphGraph",
+        ),
     )
+
+
+async def ask_chunk(chunk: Chunk, question) -> str:
+    rag_template = (
+        "Based on the following [Context] {context}, " "answer [Question] {question}."
+    )
+    template = HumanPromptTemplate.from_template(rag_template)
+    messages = template.format_messages(context=chunk.content, question=question)
     model_messages = ModelMessage.from_base_messages(messages)
     request = ModelRequest(model=model_name, messages=model_messages)
     response = await llm_client.generate(request=request)
@@ -64,38 +104,27 @@ async def chat_rag(chunk: Chunk) -> str:
     return response.text
 
 
-@pytest.mark.asyncio
-async def test_naive_graph_rag():
-    file_path = os.path.join(ROOT_PATH, file)
+async def __run_graph_rag(knowledge_file, chunk_strategy, knowledge_graph, question):
+    file_path = os.path.join(ROOT_PATH, knowledge_file).format()
     knowledge = KnowledgeFactory.from_file_path(file_path)
-    graph_store = _create_kg_connector()
 
     try:
-        chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+        chunk_parameters = ChunkParameters(chunk_strategy=chunk_strategy)
 
         # get embedding assembler
         assembler = await EmbeddingAssembler.aload_from_knowledge(
             knowledge=knowledge,
             chunk_parameters=chunk_parameters,
-            index_store=graph_store,
+            index_store=knowledge_graph,
             retrieve_strategy=RetrieverStrategy.GRAPH,
         )
         await assembler.apersist()
 
         # get embeddings retriever
         retriever = assembler.as_retriever(1)
-        chunks = await retriever.aretrieve_with_scores(
-            question, score_threshold=0.3
-        )
+        chunks = await retriever.aretrieve_with_scores(question, score_threshold=0.3)
 
         # chat
-        print(f"{await chat_rag(chunks[0])}")
-
-    except Exception as e:
-        graph_store.delete_vector_name("graph_rag_test")
-        raise e
-
-
-@pytest.mark.asyncio
-async def test_community_graph_rag():
-    pass
+        print(f"{await ask_chunk(chunks[0], question)}")
+    finally:
+        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)
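
The pre-requirements in the docstring above combine LLM and graph-store settings in one `.env`. For reference, a sketch of a complete file follows — the TuGraph variables are quoted from the docstring, while the OpenAI-style key names are an assumption based on `OpenAILLMClient` and `DefaultEmbeddingFactory.openai()` and may differ across DB-GPT versions:

```
# LLM config (url/sk) assumed by OpenAILLMClient -- names are illustrative
OPENAI_API_KEY=sk-xxx
OPENAI_API_BASE=https://api.openai.com/v1

# TuGraph config, as given in the module docstring
GRAPH_STORE_TYPE=TuGraph
TUGRAPH_HOST=127.0.0.1
TUGRAPH_PORT=7687
TUGRAPH_USERNAME=admin
TUGRAPH_PASSWORD=73@TuGraph
```
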
70 changes: 0 additions & 70 deletions examples/rag/graph_rag_summary_example.py

This file was deleted.

97 changes: 0 additions & 97 deletions examples/test_files/graph_rag_mini.md

This file was deleted.
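
To try the refactored flow outside pytest, a minimal standalone sketch of the naive path is below. It reuses only calls that appear in the diff above (`EmbeddingAssembler.aload_from_knowledge`, `apersist`, `as_retriever`, `aretrieve_with_scores`, `delete_vector_name`); the script structure, the `naive_graph_rag_demo` name, and the printing are illustrative:

```
import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever import RetrieverStrategy
from dbgpt.storage.knowledge_graph.knowledge_graph import (
    BuiltinKnowledgeGraph,
    BuiltinKnowledgeGraphConfig,
)


async def main():
    # In-memory builtin knowledge graph, as in __create_naive_kg_connector.
    knowledge_graph = BuiltinKnowledgeGraph(
        config=BuiltinKnowledgeGraphConfig(
            name="naive_graph_rag_demo",  # illustrative name
            embedding_fn=None,
            llm_client=OpenAILLMClient(),
            model_name="gpt-4o-mini",
            graph_store_type="MemoryGraph",
        ),
    )
    knowledge = KnowledgeFactory.from_file_path(
        os.path.join(ROOT_PATH, "examples/test_files/graphrag-mini.md")
    )
    try:
        # Extract a graph from the chunks and persist it, mirroring __run_graph_rag.
        assembler = await EmbeddingAssembler.aload_from_knowledge(
            knowledge=knowledge,
            chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
            index_store=knowledge_graph,
            retrieve_strategy=RetrieverStrategy.GRAPH,
        )
        await assembler.apersist()

        # Retrieve the top-1 chunk for a question, with the same 0.3 threshold.
        retriever = assembler.as_retriever(1)
        chunks = await retriever.aretrieve_with_scores(
            "What's the relationship between TuGraph and DB-GPT ?", score_threshold=0.3
        )
        print(chunks[0].content if chunks else "no chunk retrieved")
    finally:
        # Clean up the graph, as the test's finally block does.
        knowledge_graph.delete_vector_name(knowledge_graph.get_config().name)


if __name__ == "__main__":
    asyncio.run(main())
```

The community test follows the same flow, swapping in the `CommunitySummaryKnowledgeGraph` connector and the `CHUNK_BY_MARKDOWN_HEADER` chunk strategy.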

